Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. from typing import Literal
  2. from flask import request
  3. from flask_restx import marshal, reqparse
  4. from werkzeug.exceptions import Forbidden, NotFound
  5. import services.dataset_service
  6. from controllers.service_api import service_api_ns
  7. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
  8. from controllers.service_api.wraps import (
  9. DatasetApiResource,
  10. cloud_edition_billing_rate_limit_check,
  11. validate_dataset_token,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.plugin.entities.plugin import ModelProviderID
  15. from core.provider_manager import ProviderManager
  16. from fields.dataset_fields import dataset_detail_fields
  17. from fields.tag_fields import build_dataset_tag_fields
  18. from libs.login import current_user
  19. from models.dataset import Dataset, DatasetPermissionEnum
  20. from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
  21. from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
  22. from services.tag_service import TagService
  23. def _validate_name(name):
  24. if not name or len(name) < 1 or len(name) > 40:
  25. raise ValueError("Name must be between 1 to 40 characters.")
  26. return name
  27. def _validate_description_length(description):
  28. if description and len(description) > 400:
  29. raise ValueError("Description cannot exceed 400 characters.")
  30. return description
  31. # Define parsers for dataset operations
  32. dataset_create_parser = reqparse.RequestParser()
  33. dataset_create_parser.add_argument(
  34. "name",
  35. nullable=False,
  36. required=True,
  37. help="type is required. Name must be between 1 to 40 characters.",
  38. type=_validate_name,
  39. )
  40. dataset_create_parser.add_argument(
  41. "description",
  42. type=_validate_description_length,
  43. nullable=True,
  44. required=False,
  45. default="",
  46. )
  47. dataset_create_parser.add_argument(
  48. "indexing_technique",
  49. type=str,
  50. location="json",
  51. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  52. help="Invalid indexing technique.",
  53. )
  54. dataset_create_parser.add_argument(
  55. "permission",
  56. type=str,
  57. location="json",
  58. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  59. help="Invalid permission.",
  60. required=False,
  61. nullable=False,
  62. )
  63. dataset_create_parser.add_argument(
  64. "external_knowledge_api_id",
  65. type=str,
  66. nullable=True,
  67. required=False,
  68. default="_validate_name",
  69. )
  70. dataset_create_parser.add_argument(
  71. "provider",
  72. type=str,
  73. nullable=True,
  74. required=False,
  75. default="vendor",
  76. )
  77. dataset_create_parser.add_argument(
  78. "external_knowledge_id",
  79. type=str,
  80. nullable=True,
  81. required=False,
  82. )
  83. dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
  84. dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  85. dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  86. dataset_update_parser = reqparse.RequestParser()
  87. dataset_update_parser.add_argument(
  88. "name",
  89. nullable=False,
  90. help="type is required. Name must be between 1 to 40 characters.",
  91. type=_validate_name,
  92. )
  93. dataset_update_parser.add_argument(
  94. "description", location="json", store_missing=False, type=_validate_description_length
  95. )
  96. dataset_update_parser.add_argument(
  97. "indexing_technique",
  98. type=str,
  99. location="json",
  100. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  101. nullable=True,
  102. help="Invalid indexing technique.",
  103. )
  104. dataset_update_parser.add_argument(
  105. "permission",
  106. type=str,
  107. location="json",
  108. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  109. help="Invalid permission.",
  110. )
  111. dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  112. dataset_update_parser.add_argument(
  113. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  114. )
  115. dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  116. dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  117. dataset_update_parser.add_argument(
  118. "external_retrieval_model",
  119. type=dict,
  120. required=False,
  121. nullable=True,
  122. location="json",
  123. help="Invalid external retrieval model.",
  124. )
  125. dataset_update_parser.add_argument(
  126. "external_knowledge_id",
  127. type=str,
  128. required=False,
  129. nullable=True,
  130. location="json",
  131. help="Invalid external knowledge id.",
  132. )
  133. dataset_update_parser.add_argument(
  134. "external_knowledge_api_id",
  135. type=str,
  136. required=False,
  137. nullable=True,
  138. location="json",
  139. help="Invalid external knowledge api id.",
  140. )
  141. tag_create_parser = reqparse.RequestParser()
  142. tag_create_parser.add_argument(
  143. "name",
  144. nullable=False,
  145. required=True,
  146. help="Name must be between 1 to 50 characters.",
  147. type=lambda x: x
  148. if x and 1 <= len(x) <= 50
  149. else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
  150. )
  151. tag_update_parser = reqparse.RequestParser()
  152. tag_update_parser.add_argument(
  153. "name",
  154. nullable=False,
  155. required=True,
  156. help="Name must be between 1 to 50 characters.",
  157. type=lambda x: x
  158. if x and 1 <= len(x) <= 50
  159. else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
  160. )
  161. tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  162. tag_delete_parser = reqparse.RequestParser()
  163. tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  164. tag_binding_parser = reqparse.RequestParser()
  165. tag_binding_parser.add_argument(
  166. "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
  167. )
  168. tag_binding_parser.add_argument(
  169. "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
  170. )
  171. tag_unbinding_parser = reqparse.RequestParser()
  172. tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
  173. tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
  174. @service_api_ns.route("/datasets")
  175. class DatasetListApi(DatasetApiResource):
  176. """Resource for datasets."""
  177. @service_api_ns.doc("list_datasets")
  178. @service_api_ns.doc(description="List all datasets")
  179. @service_api_ns.doc(
  180. responses={
  181. 200: "Datasets retrieved successfully",
  182. 401: "Unauthorized - invalid API token",
  183. }
  184. )
  185. def get(self, tenant_id):
  186. """Resource for getting datasets."""
  187. page = request.args.get("page", default=1, type=int)
  188. limit = request.args.get("limit", default=20, type=int)
  189. # provider = request.args.get("provider", default="vendor")
  190. search = request.args.get("keyword", default=None, type=str)
  191. tag_ids = request.args.getlist("tag_ids")
  192. include_all = request.args.get("include_all", default="false").lower() == "true"
  193. datasets, total = DatasetService.get_datasets(
  194. page, limit, tenant_id, current_user, search, tag_ids, include_all
  195. )
  196. # check embedding setting
  197. provider_manager = ProviderManager()
  198. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  199. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  200. model_names = []
  201. for embedding_model in embedding_models:
  202. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  203. data = marshal(datasets, dataset_detail_fields)
  204. for item in data:
  205. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  206. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  207. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  208. if item_model in model_names:
  209. item["embedding_available"] = True
  210. else:
  211. item["embedding_available"] = False
  212. else:
  213. item["embedding_available"] = True
  214. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  215. return response, 200
  216. @service_api_ns.expect(dataset_create_parser)
  217. @service_api_ns.doc("create_dataset")
  218. @service_api_ns.doc(description="Create a new dataset")
  219. @service_api_ns.doc(
  220. responses={
  221. 200: "Dataset created successfully",
  222. 401: "Unauthorized - invalid API token",
  223. 400: "Bad request - invalid parameters",
  224. }
  225. )
  226. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  227. def post(self, tenant_id):
  228. """Resource for creating datasets."""
  229. args = dataset_create_parser.parse_args()
  230. if args.get("embedding_model_provider"):
  231. DatasetService.check_embedding_model_setting(
  232. tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
  233. )
  234. if (
  235. args.get("retrieval_model")
  236. and args.get("retrieval_model").get("reranking_model")
  237. and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  238. ):
  239. DatasetService.check_reranking_model_setting(
  240. tenant_id,
  241. args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  242. args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  243. )
  244. try:
  245. dataset = DatasetService.create_empty_dataset(
  246. tenant_id=tenant_id,
  247. name=args["name"],
  248. description=args["description"],
  249. indexing_technique=args["indexing_technique"],
  250. account=current_user,
  251. permission=args["permission"],
  252. provider=args["provider"],
  253. external_knowledge_api_id=args["external_knowledge_api_id"],
  254. external_knowledge_id=args["external_knowledge_id"],
  255. embedding_model_provider=args["embedding_model_provider"],
  256. embedding_model_name=args["embedding_model"],
  257. retrieval_model=RetrievalModel(**args["retrieval_model"])
  258. if args["retrieval_model"] is not None
  259. else None,
  260. )
  261. except services.errors.dataset.DatasetNameDuplicateError:
  262. raise DatasetNameDuplicateError()
  263. return marshal(dataset, dataset_detail_fields), 200
  264. @service_api_ns.route("/datasets/<uuid:dataset_id>")
  265. class DatasetApi(DatasetApiResource):
  266. """Resource for dataset."""
  267. @service_api_ns.doc("get_dataset")
  268. @service_api_ns.doc(description="Get a specific dataset by ID")
  269. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  270. @service_api_ns.doc(
  271. responses={
  272. 200: "Dataset retrieved successfully",
  273. 401: "Unauthorized - invalid API token",
  274. 403: "Forbidden - insufficient permissions",
  275. 404: "Dataset not found",
  276. }
  277. )
  278. def get(self, _, dataset_id):
  279. dataset_id_str = str(dataset_id)
  280. dataset = DatasetService.get_dataset(dataset_id_str)
  281. if dataset is None:
  282. raise NotFound("Dataset not found.")
  283. try:
  284. DatasetService.check_dataset_permission(dataset, current_user)
  285. except services.errors.account.NoPermissionError as e:
  286. raise Forbidden(str(e))
  287. data = marshal(dataset, dataset_detail_fields)
  288. if data.get("permission") == "partial_members":
  289. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  290. data.update({"partial_member_list": part_users_list})
  291. # check embedding setting
  292. provider_manager = ProviderManager()
  293. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  294. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  295. model_names = []
  296. for embedding_model in embedding_models:
  297. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  298. if data["indexing_technique"] == "high_quality":
  299. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  300. if item_model in model_names:
  301. data["embedding_available"] = True
  302. else:
  303. data["embedding_available"] = False
  304. else:
  305. data["embedding_available"] = True
  306. if data.get("permission") == "partial_members":
  307. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  308. data.update({"partial_member_list": part_users_list})
  309. return data, 200
  310. @service_api_ns.expect(dataset_update_parser)
  311. @service_api_ns.doc("update_dataset")
  312. @service_api_ns.doc(description="Update an existing dataset")
  313. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  314. @service_api_ns.doc(
  315. responses={
  316. 200: "Dataset updated successfully",
  317. 401: "Unauthorized - invalid API token",
  318. 403: "Forbidden - insufficient permissions",
  319. 404: "Dataset not found",
  320. }
  321. )
  322. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  323. def patch(self, _, dataset_id):
  324. dataset_id_str = str(dataset_id)
  325. dataset = DatasetService.get_dataset(dataset_id_str)
  326. if dataset is None:
  327. raise NotFound("Dataset not found.")
  328. args = dataset_update_parser.parse_args()
  329. data = request.get_json()
  330. # check embedding model setting
  331. if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
  332. DatasetService.check_embedding_model_setting(
  333. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  334. )
  335. if (
  336. data.get("retrieval_model")
  337. and data.get("retrieval_model").get("reranking_model")
  338. and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  339. ):
  340. DatasetService.check_reranking_model_setting(
  341. dataset.tenant_id,
  342. data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  343. data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  344. )
  345. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  346. DatasetPermissionService.check_permission(
  347. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  348. )
  349. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  350. if dataset is None:
  351. raise NotFound("Dataset not found.")
  352. result_data = marshal(dataset, dataset_detail_fields)
  353. tenant_id = current_user.current_tenant_id
  354. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  355. DatasetPermissionService.update_partial_member_list(
  356. tenant_id, dataset_id_str, data.get("partial_member_list")
  357. )
  358. # clear partial member list when permission is only_me or all_team_members
  359. elif (
  360. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  361. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  362. ):
  363. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  364. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  365. result_data.update({"partial_member_list": partial_member_list})
  366. return result_data, 200
  367. @service_api_ns.doc("delete_dataset")
  368. @service_api_ns.doc(description="Delete a dataset")
  369. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  370. @service_api_ns.doc(
  371. responses={
  372. 204: "Dataset deleted successfully",
  373. 401: "Unauthorized - invalid API token",
  374. 404: "Dataset not found",
  375. 409: "Conflict - dataset is in use",
  376. }
  377. )
  378. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  379. def delete(self, _, dataset_id):
  380. """
  381. Deletes a dataset given its ID.
  382. Args:
  383. _: ignore
  384. dataset_id (UUID): The ID of the dataset to be deleted.
  385. Returns:
  386. dict: A dictionary with a key 'result' and a value 'success'
  387. if the dataset was successfully deleted. Omitted in HTTP response.
  388. int: HTTP status code 204 indicating that the operation was successful.
  389. Raises:
  390. NotFound: If the dataset with the given ID does not exist.
  391. """
  392. dataset_id_str = str(dataset_id)
  393. try:
  394. if DatasetService.delete_dataset(dataset_id_str, current_user):
  395. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  396. return 204
  397. else:
  398. raise NotFound("Dataset not found.")
  399. except services.errors.dataset.DatasetInUseError:
  400. raise DatasetInUseError()
  401. @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>")
  402. class DocumentStatusApi(DatasetApiResource):
  403. """Resource for batch document status operations."""
  404. @service_api_ns.doc("update_document_status")
  405. @service_api_ns.doc(description="Batch update document status")
  406. @service_api_ns.doc(
  407. params={
  408. "dataset_id": "Dataset ID",
  409. "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
  410. }
  411. )
  412. @service_api_ns.doc(
  413. responses={
  414. 200: "Document status updated successfully",
  415. 401: "Unauthorized - invalid API token",
  416. 403: "Forbidden - insufficient permissions",
  417. 404: "Dataset not found",
  418. 400: "Bad request - invalid action",
  419. }
  420. )
  421. def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
  422. """
  423. Batch update document status.
  424. Args:
  425. tenant_id: tenant id
  426. dataset_id: dataset id
  427. action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
  428. Returns:
  429. dict: A dictionary with a key 'result' and a value 'success'
  430. int: HTTP status code 200 indicating that the operation was successful.
  431. Raises:
  432. NotFound: If the dataset with the given ID does not exist.
  433. Forbidden: If the user does not have permission.
  434. InvalidActionError: If the action is invalid or cannot be performed.
  435. """
  436. dataset_id_str = str(dataset_id)
  437. dataset = DatasetService.get_dataset(dataset_id_str)
  438. if dataset is None:
  439. raise NotFound("Dataset not found.")
  440. # Check user's permission
  441. try:
  442. DatasetService.check_dataset_permission(dataset, current_user)
  443. except services.errors.account.NoPermissionError as e:
  444. raise Forbidden(str(e))
  445. # Check dataset model setting
  446. DatasetService.check_dataset_model_setting(dataset)
  447. # Get document IDs from request body
  448. data = request.get_json()
  449. document_ids = data.get("document_ids", [])
  450. try:
  451. DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
  452. except services.errors.document.DocumentIndexingError as e:
  453. raise InvalidActionError(str(e))
  454. except ValueError as e:
  455. raise InvalidActionError(str(e))
  456. return {"result": "success"}, 200
  457. @service_api_ns.route("/datasets/tags")
  458. class DatasetTagsApi(DatasetApiResource):
  459. @service_api_ns.doc("list_dataset_tags")
  460. @service_api_ns.doc(description="Get all knowledge type tags")
  461. @service_api_ns.doc(
  462. responses={
  463. 200: "Tags retrieved successfully",
  464. 401: "Unauthorized - invalid API token",
  465. }
  466. )
  467. @validate_dataset_token
  468. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  469. def get(self, _, dataset_id):
  470. """Get all knowledge type tags."""
  471. tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
  472. return tags, 200
  473. @service_api_ns.expect(tag_create_parser)
  474. @service_api_ns.doc("create_dataset_tag")
  475. @service_api_ns.doc(description="Add a knowledge type tag")
  476. @service_api_ns.doc(
  477. responses={
  478. 200: "Tag created successfully",
  479. 401: "Unauthorized - invalid API token",
  480. 403: "Forbidden - insufficient permissions",
  481. }
  482. )
  483. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  484. @validate_dataset_token
  485. def post(self, _, dataset_id):
  486. """Add a knowledge type tag."""
  487. if not (current_user.is_editor or current_user.is_dataset_editor):
  488. raise Forbidden()
  489. args = tag_create_parser.parse_args()
  490. args["type"] = "knowledge"
  491. tag = TagService.save_tags(args)
  492. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
  493. return response, 200
  494. @service_api_ns.expect(tag_update_parser)
  495. @service_api_ns.doc("update_dataset_tag")
  496. @service_api_ns.doc(description="Update a knowledge type tag")
  497. @service_api_ns.doc(
  498. responses={
  499. 200: "Tag updated successfully",
  500. 401: "Unauthorized - invalid API token",
  501. 403: "Forbidden - insufficient permissions",
  502. }
  503. )
  504. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  505. @validate_dataset_token
  506. def patch(self, _, dataset_id):
  507. if not (current_user.is_editor or current_user.is_dataset_editor):
  508. raise Forbidden()
  509. args = tag_update_parser.parse_args()
  510. args["type"] = "knowledge"
  511. tag = TagService.update_tags(args, args.get("tag_id"))
  512. binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
  513. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
  514. return response, 200
  515. @service_api_ns.expect(tag_delete_parser)
  516. @service_api_ns.doc("delete_dataset_tag")
  517. @service_api_ns.doc(description="Delete a knowledge type tag")
  518. @service_api_ns.doc(
  519. responses={
  520. 204: "Tag deleted successfully",
  521. 401: "Unauthorized - invalid API token",
  522. 403: "Forbidden - insufficient permissions",
  523. }
  524. )
  525. @validate_dataset_token
  526. def delete(self, _, dataset_id):
  527. """Delete a knowledge type tag."""
  528. if not current_user.is_editor:
  529. raise Forbidden()
  530. args = tag_delete_parser.parse_args()
  531. TagService.delete_tag(args.get("tag_id"))
  532. return 204
  533. @service_api_ns.route("/datasets/tags/binding")
  534. class DatasetTagBindingApi(DatasetApiResource):
  535. @service_api_ns.expect(tag_binding_parser)
  536. @service_api_ns.doc("bind_dataset_tags")
  537. @service_api_ns.doc(description="Bind tags to a dataset")
  538. @service_api_ns.doc(
  539. responses={
  540. 204: "Tags bound successfully",
  541. 401: "Unauthorized - invalid API token",
  542. 403: "Forbidden - insufficient permissions",
  543. }
  544. )
  545. @validate_dataset_token
  546. def post(self, _, dataset_id):
  547. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  548. if not (current_user.is_editor or current_user.is_dataset_editor):
  549. raise Forbidden()
  550. args = tag_binding_parser.parse_args()
  551. args["type"] = "knowledge"
  552. TagService.save_tag_binding(args)
  553. return 204
  554. @service_api_ns.route("/datasets/tags/unbinding")
  555. class DatasetTagUnbindingApi(DatasetApiResource):
  556. @service_api_ns.expect(tag_unbinding_parser)
  557. @service_api_ns.doc("unbind_dataset_tag")
  558. @service_api_ns.doc(description="Unbind a tag from a dataset")
  559. @service_api_ns.doc(
  560. responses={
  561. 204: "Tag unbound successfully",
  562. 401: "Unauthorized - invalid API token",
  563. 403: "Forbidden - insufficient permissions",
  564. }
  565. )
  566. @validate_dataset_token
  567. def post(self, _, dataset_id):
  568. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  569. if not (current_user.is_editor or current_user.is_dataset_editor):
  570. raise Forbidden()
  571. args = tag_unbinding_parser.parse_args()
  572. args["type"] = "knowledge"
  573. TagService.delete_tag_binding(args)
  574. return 204
  575. @service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
  576. class DatasetTagsBindingStatusApi(DatasetApiResource):
  577. @service_api_ns.doc("get_dataset_tags_binding_status")
  578. @service_api_ns.doc(description="Get tags bound to a specific dataset")
  579. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  580. @service_api_ns.doc(
  581. responses={
  582. 200: "Tags retrieved successfully",
  583. 401: "Unauthorized - invalid API token",
  584. }
  585. )
  586. @validate_dataset_token
  587. def get(self, _, *args, **kwargs):
  588. """Get all knowledge type tags."""
  589. dataset_id = kwargs.get("dataset_id")
  590. tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
  591. tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
  592. response = {"data": tags_list, "total": len(tags)}
  593. return response, 200