選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

dataset.py 27KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706
  1. from typing import Any, Literal, cast
  2. from flask import request
  3. from flask_restx import marshal, reqparse
  4. from werkzeug.exceptions import Forbidden, NotFound
  5. import services
  6. from controllers.service_api import service_api_ns
  7. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
  8. from controllers.service_api.wraps import (
  9. DatasetApiResource,
  10. cloud_edition_billing_rate_limit_check,
  11. validate_dataset_token,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.provider_manager import ProviderManager
  15. from fields.dataset_fields import dataset_detail_fields
  16. from fields.tag_fields import build_dataset_tag_fields
  17. from libs.login import current_user
  18. from models.account import Account
  19. from models.dataset import Dataset, DatasetPermissionEnum
  20. from models.provider_ids import ModelProviderID
  21. from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
  22. from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
  23. from services.tag_service import TagService
  24. def _validate_name(name):
  25. if not name or len(name) < 1 or len(name) > 40:
  26. raise ValueError("Name must be between 1 to 40 characters.")
  27. return name
  28. def _validate_description_length(description):
  29. if description and len(description) > 400:
  30. raise ValueError("Description cannot exceed 400 characters.")
  31. return description
  32. # Define parsers for dataset operations
  33. dataset_create_parser = reqparse.RequestParser()
  34. dataset_create_parser.add_argument(
  35. "name",
  36. nullable=False,
  37. required=True,
  38. help="type is required. Name must be between 1 to 40 characters.",
  39. type=_validate_name,
  40. )
  41. dataset_create_parser.add_argument(
  42. "description",
  43. type=_validate_description_length,
  44. nullable=True,
  45. required=False,
  46. default="",
  47. )
  48. dataset_create_parser.add_argument(
  49. "indexing_technique",
  50. type=str,
  51. location="json",
  52. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  53. help="Invalid indexing technique.",
  54. )
  55. dataset_create_parser.add_argument(
  56. "permission",
  57. type=str,
  58. location="json",
  59. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  60. help="Invalid permission.",
  61. required=False,
  62. nullable=False,
  63. )
  64. dataset_create_parser.add_argument(
  65. "external_knowledge_api_id",
  66. type=str,
  67. nullable=True,
  68. required=False,
  69. default="_validate_name",
  70. )
  71. dataset_create_parser.add_argument(
  72. "provider",
  73. type=str,
  74. nullable=True,
  75. required=False,
  76. default="vendor",
  77. )
  78. dataset_create_parser.add_argument(
  79. "external_knowledge_id",
  80. type=str,
  81. nullable=True,
  82. required=False,
  83. )
  84. dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
  85. dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  86. dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  87. dataset_update_parser = reqparse.RequestParser()
  88. dataset_update_parser.add_argument(
  89. "name",
  90. nullable=False,
  91. help="type is required. Name must be between 1 to 40 characters.",
  92. type=_validate_name,
  93. )
  94. dataset_update_parser.add_argument(
  95. "description", location="json", store_missing=False, type=_validate_description_length
  96. )
  97. dataset_update_parser.add_argument(
  98. "indexing_technique",
  99. type=str,
  100. location="json",
  101. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  102. nullable=True,
  103. help="Invalid indexing technique.",
  104. )
  105. dataset_update_parser.add_argument(
  106. "permission",
  107. type=str,
  108. location="json",
  109. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  110. help="Invalid permission.",
  111. )
  112. dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  113. dataset_update_parser.add_argument(
  114. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  115. )
  116. dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  117. dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  118. dataset_update_parser.add_argument(
  119. "external_retrieval_model",
  120. type=dict,
  121. required=False,
  122. nullable=True,
  123. location="json",
  124. help="Invalid external retrieval model.",
  125. )
  126. dataset_update_parser.add_argument(
  127. "external_knowledge_id",
  128. type=str,
  129. required=False,
  130. nullable=True,
  131. location="json",
  132. help="Invalid external knowledge id.",
  133. )
  134. dataset_update_parser.add_argument(
  135. "external_knowledge_api_id",
  136. type=str,
  137. required=False,
  138. nullable=True,
  139. location="json",
  140. help="Invalid external knowledge api id.",
  141. )
  142. tag_create_parser = reqparse.RequestParser()
  143. tag_create_parser.add_argument(
  144. "name",
  145. nullable=False,
  146. required=True,
  147. help="Name must be between 1 to 50 characters.",
  148. type=lambda x: x
  149. if x and 1 <= len(x) <= 50
  150. else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
  151. )
  152. tag_update_parser = reqparse.RequestParser()
  153. tag_update_parser.add_argument(
  154. "name",
  155. nullable=False,
  156. required=True,
  157. help="Name must be between 1 to 50 characters.",
  158. type=lambda x: x
  159. if x and 1 <= len(x) <= 50
  160. else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
  161. )
  162. tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  163. tag_delete_parser = reqparse.RequestParser()
  164. tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  165. tag_binding_parser = reqparse.RequestParser()
  166. tag_binding_parser.add_argument(
  167. "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
  168. )
  169. tag_binding_parser.add_argument(
  170. "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
  171. )
  172. tag_unbinding_parser = reqparse.RequestParser()
  173. tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
  174. tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
  175. @service_api_ns.route("/datasets")
  176. class DatasetListApi(DatasetApiResource):
  177. """Resource for datasets."""
  178. @service_api_ns.doc("list_datasets")
  179. @service_api_ns.doc(description="List all datasets")
  180. @service_api_ns.doc(
  181. responses={
  182. 200: "Datasets retrieved successfully",
  183. 401: "Unauthorized - invalid API token",
  184. }
  185. )
  186. def get(self, tenant_id):
  187. """Resource for getting datasets."""
  188. page = request.args.get("page", default=1, type=int)
  189. limit = request.args.get("limit", default=20, type=int)
  190. # provider = request.args.get("provider", default="vendor")
  191. search = request.args.get("keyword", default=None, type=str)
  192. tag_ids = request.args.getlist("tag_ids")
  193. include_all = request.args.get("include_all", default="false").lower() == "true"
  194. datasets, total = DatasetService.get_datasets(
  195. page, limit, tenant_id, current_user, search, tag_ids, include_all
  196. )
  197. # check embedding setting
  198. provider_manager = ProviderManager()
  199. assert isinstance(current_user, Account)
  200. cid = current_user.current_tenant_id
  201. assert cid is not None
  202. configurations = provider_manager.get_configurations(tenant_id=cid)
  203. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  204. model_names = []
  205. for embedding_model in embedding_models:
  206. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  207. data = marshal(datasets, dataset_detail_fields)
  208. for item in data:
  209. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  210. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  211. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  212. if item_model in model_names:
  213. item["embedding_available"] = True
  214. else:
  215. item["embedding_available"] = False
  216. else:
  217. item["embedding_available"] = True
  218. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  219. return response, 200
  220. @service_api_ns.expect(dataset_create_parser)
  221. @service_api_ns.doc("create_dataset")
  222. @service_api_ns.doc(description="Create a new dataset")
  223. @service_api_ns.doc(
  224. responses={
  225. 200: "Dataset created successfully",
  226. 401: "Unauthorized - invalid API token",
  227. 400: "Bad request - invalid parameters",
  228. }
  229. )
  230. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  231. def post(self, tenant_id):
  232. """Resource for creating datasets."""
  233. args = dataset_create_parser.parse_args()
  234. embedding_model_provider = args.get("embedding_model_provider")
  235. embedding_model = args.get("embedding_model")
  236. if embedding_model_provider and embedding_model:
  237. DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
  238. retrieval_model = args.get("retrieval_model")
  239. if (
  240. retrieval_model
  241. and retrieval_model.get("reranking_model")
  242. and retrieval_model.get("reranking_model").get("reranking_provider_name")
  243. ):
  244. DatasetService.check_reranking_model_setting(
  245. tenant_id,
  246. retrieval_model.get("reranking_model").get("reranking_provider_name"),
  247. retrieval_model.get("reranking_model").get("reranking_model_name"),
  248. )
  249. try:
  250. assert isinstance(current_user, Account)
  251. dataset = DatasetService.create_empty_dataset(
  252. tenant_id=tenant_id,
  253. name=args["name"],
  254. description=args["description"],
  255. indexing_technique=args["indexing_technique"],
  256. account=current_user,
  257. permission=args["permission"],
  258. provider=args["provider"],
  259. external_knowledge_api_id=args["external_knowledge_api_id"],
  260. external_knowledge_id=args["external_knowledge_id"],
  261. embedding_model_provider=args["embedding_model_provider"],
  262. embedding_model_name=args["embedding_model"],
  263. retrieval_model=RetrievalModel(**args["retrieval_model"])
  264. if args["retrieval_model"] is not None
  265. else None,
  266. )
  267. except services.errors.dataset.DatasetNameDuplicateError:
  268. raise DatasetNameDuplicateError()
  269. return marshal(dataset, dataset_detail_fields), 200
  270. @service_api_ns.route("/datasets/<uuid:dataset_id>")
  271. class DatasetApi(DatasetApiResource):
  272. """Resource for dataset."""
  273. @service_api_ns.doc("get_dataset")
  274. @service_api_ns.doc(description="Get a specific dataset by ID")
  275. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  276. @service_api_ns.doc(
  277. responses={
  278. 200: "Dataset retrieved successfully",
  279. 401: "Unauthorized - invalid API token",
  280. 403: "Forbidden - insufficient permissions",
  281. 404: "Dataset not found",
  282. }
  283. )
  284. def get(self, _, dataset_id):
  285. dataset_id_str = str(dataset_id)
  286. dataset = DatasetService.get_dataset(dataset_id_str)
  287. if dataset is None:
  288. raise NotFound("Dataset not found.")
  289. try:
  290. DatasetService.check_dataset_permission(dataset, current_user)
  291. except services.errors.account.NoPermissionError as e:
  292. raise Forbidden(str(e))
  293. data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
  294. # check embedding setting
  295. provider_manager = ProviderManager()
  296. assert isinstance(current_user, Account)
  297. cid = current_user.current_tenant_id
  298. assert cid is not None
  299. configurations = provider_manager.get_configurations(tenant_id=cid)
  300. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  301. model_names = []
  302. for embedding_model in embedding_models:
  303. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  304. if data.get("indexing_technique") == "high_quality":
  305. item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
  306. if item_model in model_names:
  307. data["embedding_available"] = True
  308. else:
  309. data["embedding_available"] = False
  310. else:
  311. data["embedding_available"] = True
  312. # force update search method to keyword_search if indexing_technique is economic
  313. retrieval_model_dict = data.get("retrieval_model_dict")
  314. if retrieval_model_dict:
  315. retrieval_model_dict["search_method"] = "keyword_search"
  316. if data.get("permission") == "partial_members":
  317. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  318. data.update({"partial_member_list": part_users_list})
  319. return data, 200
  320. @service_api_ns.expect(dataset_update_parser)
  321. @service_api_ns.doc("update_dataset")
  322. @service_api_ns.doc(description="Update an existing dataset")
  323. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  324. @service_api_ns.doc(
  325. responses={
  326. 200: "Dataset updated successfully",
  327. 401: "Unauthorized - invalid API token",
  328. 403: "Forbidden - insufficient permissions",
  329. 404: "Dataset not found",
  330. }
  331. )
  332. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  333. def patch(self, _, dataset_id):
  334. dataset_id_str = str(dataset_id)
  335. dataset = DatasetService.get_dataset(dataset_id_str)
  336. if dataset is None:
  337. raise NotFound("Dataset not found.")
  338. args = dataset_update_parser.parse_args()
  339. data = request.get_json()
  340. # check embedding model setting
  341. embedding_model_provider = data.get("embedding_model_provider")
  342. embedding_model = data.get("embedding_model")
  343. if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
  344. if embedding_model_provider and embedding_model:
  345. DatasetService.check_embedding_model_setting(
  346. dataset.tenant_id, embedding_model_provider, embedding_model
  347. )
  348. retrieval_model = data.get("retrieval_model")
  349. if (
  350. retrieval_model
  351. and retrieval_model.get("reranking_model")
  352. and retrieval_model.get("reranking_model").get("reranking_provider_name")
  353. ):
  354. DatasetService.check_reranking_model_setting(
  355. dataset.tenant_id,
  356. retrieval_model.get("reranking_model").get("reranking_provider_name"),
  357. retrieval_model.get("reranking_model").get("reranking_model_name"),
  358. )
  359. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  360. DatasetPermissionService.check_permission(
  361. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  362. )
  363. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  364. if dataset is None:
  365. raise NotFound("Dataset not found.")
  366. result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
  367. assert isinstance(current_user, Account)
  368. tenant_id = current_user.current_tenant_id
  369. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  370. DatasetPermissionService.update_partial_member_list(
  371. tenant_id, dataset_id_str, data.get("partial_member_list")
  372. )
  373. # clear partial member list when permission is only_me or all_team_members
  374. elif (
  375. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  376. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  377. ):
  378. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  379. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  380. result_data.update({"partial_member_list": partial_member_list})
  381. return result_data, 200
  382. @service_api_ns.doc("delete_dataset")
  383. @service_api_ns.doc(description="Delete a dataset")
  384. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  385. @service_api_ns.doc(
  386. responses={
  387. 204: "Dataset deleted successfully",
  388. 401: "Unauthorized - invalid API token",
  389. 404: "Dataset not found",
  390. 409: "Conflict - dataset is in use",
  391. }
  392. )
  393. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  394. def delete(self, _, dataset_id):
  395. """
  396. Deletes a dataset given its ID.
  397. Args:
  398. _: ignore
  399. dataset_id (UUID): The ID of the dataset to be deleted.
  400. Returns:
  401. dict: A dictionary with a key 'result' and a value 'success'
  402. if the dataset was successfully deleted. Omitted in HTTP response.
  403. int: HTTP status code 204 indicating that the operation was successful.
  404. Raises:
  405. NotFound: If the dataset with the given ID does not exist.
  406. """
  407. dataset_id_str = str(dataset_id)
  408. try:
  409. if DatasetService.delete_dataset(dataset_id_str, current_user):
  410. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  411. return 204
  412. else:
  413. raise NotFound("Dataset not found.")
  414. except services.errors.dataset.DatasetInUseError:
  415. raise DatasetInUseError()
  416. @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>")
  417. class DocumentStatusApi(DatasetApiResource):
  418. """Resource for batch document status operations."""
  419. @service_api_ns.doc("update_document_status")
  420. @service_api_ns.doc(description="Batch update document status")
  421. @service_api_ns.doc(
  422. params={
  423. "dataset_id": "Dataset ID",
  424. "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
  425. }
  426. )
  427. @service_api_ns.doc(
  428. responses={
  429. 200: "Document status updated successfully",
  430. 401: "Unauthorized - invalid API token",
  431. 403: "Forbidden - insufficient permissions",
  432. 404: "Dataset not found",
  433. 400: "Bad request - invalid action",
  434. }
  435. )
  436. def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
  437. """
  438. Batch update document status.
  439. Args:
  440. tenant_id: tenant id
  441. dataset_id: dataset id
  442. action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
  443. Returns:
  444. dict: A dictionary with a key 'result' and a value 'success'
  445. int: HTTP status code 200 indicating that the operation was successful.
  446. Raises:
  447. NotFound: If the dataset with the given ID does not exist.
  448. Forbidden: If the user does not have permission.
  449. InvalidActionError: If the action is invalid or cannot be performed.
  450. """
  451. dataset_id_str = str(dataset_id)
  452. dataset = DatasetService.get_dataset(dataset_id_str)
  453. if dataset is None:
  454. raise NotFound("Dataset not found.")
  455. # Check user's permission
  456. try:
  457. DatasetService.check_dataset_permission(dataset, current_user)
  458. except services.errors.account.NoPermissionError as e:
  459. raise Forbidden(str(e))
  460. # Check dataset model setting
  461. DatasetService.check_dataset_model_setting(dataset)
  462. # Get document IDs from request body
  463. data = request.get_json()
  464. document_ids = data.get("document_ids", [])
  465. try:
  466. DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
  467. except services.errors.document.DocumentIndexingError as e:
  468. raise InvalidActionError(str(e))
  469. except ValueError as e:
  470. raise InvalidActionError(str(e))
  471. return {"result": "success"}, 200
  472. @service_api_ns.route("/datasets/tags")
  473. class DatasetTagsApi(DatasetApiResource):
  474. @service_api_ns.doc("list_dataset_tags")
  475. @service_api_ns.doc(description="Get all knowledge type tags")
  476. @service_api_ns.doc(
  477. responses={
  478. 200: "Tags retrieved successfully",
  479. 401: "Unauthorized - invalid API token",
  480. }
  481. )
  482. @validate_dataset_token
  483. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  484. def get(self, _, dataset_id):
  485. """Get all knowledge type tags."""
  486. assert isinstance(current_user, Account)
  487. cid = current_user.current_tenant_id
  488. assert cid is not None
  489. tags = TagService.get_tags("knowledge", cid)
  490. return tags, 200
  491. @service_api_ns.expect(tag_create_parser)
  492. @service_api_ns.doc("create_dataset_tag")
  493. @service_api_ns.doc(description="Add a knowledge type tag")
  494. @service_api_ns.doc(
  495. responses={
  496. 200: "Tag created successfully",
  497. 401: "Unauthorized - invalid API token",
  498. 403: "Forbidden - insufficient permissions",
  499. }
  500. )
  501. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  502. @validate_dataset_token
  503. def post(self, _, dataset_id):
  504. """Add a knowledge type tag."""
  505. assert isinstance(current_user, Account)
  506. if not (current_user.has_edit_permission or current_user.is_dataset_editor):
  507. raise Forbidden()
  508. args = tag_create_parser.parse_args()
  509. args["type"] = "knowledge"
  510. tag = TagService.save_tags(args)
  511. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
  512. return response, 200
  513. @service_api_ns.expect(tag_update_parser)
  514. @service_api_ns.doc("update_dataset_tag")
  515. @service_api_ns.doc(description="Update a knowledge type tag")
  516. @service_api_ns.doc(
  517. responses={
  518. 200: "Tag updated successfully",
  519. 401: "Unauthorized - invalid API token",
  520. 403: "Forbidden - insufficient permissions",
  521. }
  522. )
  523. @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
  524. @validate_dataset_token
  525. def patch(self, _, dataset_id):
  526. assert isinstance(current_user, Account)
  527. if not (current_user.has_edit_permission or current_user.is_dataset_editor):
  528. raise Forbidden()
  529. args = tag_update_parser.parse_args()
  530. args["type"] = "knowledge"
  531. tag_id = args["tag_id"]
  532. tag = TagService.update_tags(args, tag_id)
  533. binding_count = TagService.get_tag_binding_count(tag_id)
  534. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
  535. return response, 200
  536. @service_api_ns.expect(tag_delete_parser)
  537. @service_api_ns.doc("delete_dataset_tag")
  538. @service_api_ns.doc(description="Delete a knowledge type tag")
  539. @service_api_ns.doc(
  540. responses={
  541. 204: "Tag deleted successfully",
  542. 401: "Unauthorized - invalid API token",
  543. 403: "Forbidden - insufficient permissions",
  544. }
  545. )
  546. @validate_dataset_token
  547. def delete(self, _, dataset_id):
  548. """Delete a knowledge type tag."""
  549. assert isinstance(current_user, Account)
  550. if not current_user.has_edit_permission:
  551. raise Forbidden()
  552. args = tag_delete_parser.parse_args()
  553. TagService.delete_tag(args["tag_id"])
  554. return 204
  555. @service_api_ns.route("/datasets/tags/binding")
  556. class DatasetTagBindingApi(DatasetApiResource):
  557. @service_api_ns.expect(tag_binding_parser)
  558. @service_api_ns.doc("bind_dataset_tags")
  559. @service_api_ns.doc(description="Bind tags to a dataset")
  560. @service_api_ns.doc(
  561. responses={
  562. 204: "Tags bound successfully",
  563. 401: "Unauthorized - invalid API token",
  564. 403: "Forbidden - insufficient permissions",
  565. }
  566. )
  567. @validate_dataset_token
  568. def post(self, _, dataset_id):
  569. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  570. assert isinstance(current_user, Account)
  571. if not (current_user.has_edit_permission or current_user.is_dataset_editor):
  572. raise Forbidden()
  573. args = tag_binding_parser.parse_args()
  574. args["type"] = "knowledge"
  575. TagService.save_tag_binding(args)
  576. return 204
  577. @service_api_ns.route("/datasets/tags/unbinding")
  578. class DatasetTagUnbindingApi(DatasetApiResource):
  579. @service_api_ns.expect(tag_unbinding_parser)
  580. @service_api_ns.doc("unbind_dataset_tag")
  581. @service_api_ns.doc(description="Unbind a tag from a dataset")
  582. @service_api_ns.doc(
  583. responses={
  584. 204: "Tag unbound successfully",
  585. 401: "Unauthorized - invalid API token",
  586. 403: "Forbidden - insufficient permissions",
  587. }
  588. )
  589. @validate_dataset_token
  590. def post(self, _, dataset_id):
  591. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  592. assert isinstance(current_user, Account)
  593. if not (current_user.has_edit_permission or current_user.is_dataset_editor):
  594. raise Forbidden()
  595. args = tag_unbinding_parser.parse_args()
  596. args["type"] = "knowledge"
  597. TagService.delete_tag_binding(args)
  598. return 204
  599. @service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
  600. class DatasetTagsBindingStatusApi(DatasetApiResource):
  601. @service_api_ns.doc("get_dataset_tags_binding_status")
  602. @service_api_ns.doc(description="Get tags bound to a specific dataset")
  603. @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
  604. @service_api_ns.doc(
  605. responses={
  606. 200: "Tags retrieved successfully",
  607. 401: "Unauthorized - invalid API token",
  608. }
  609. )
  610. @validate_dataset_token
  611. def get(self, _, *args, **kwargs):
  612. """Get all knowledge type tags."""
  613. dataset_id = kwargs.get("dataset_id")
  614. assert isinstance(current_user, Account)
  615. assert current_user.current_tenant_id is not None
  616. tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
  617. tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
  618. response = {"data": tags_list, "total": len(tags)}
  619. return response, 200