Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. from flask import request
  2. from flask_restful import marshal, marshal_with, reqparse
  3. from werkzeug.exceptions import Forbidden, NotFound
  4. import services.dataset_service
  5. from controllers.service_api import api
  6. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
  7. from controllers.service_api.wraps import (
  8. DatasetApiResource,
  9. cloud_edition_billing_rate_limit_check,
  10. validate_dataset_token,
  11. )
  12. from core.model_runtime.entities.model_entities import ModelType
  13. from core.plugin.entities.plugin import ModelProviderID
  14. from core.provider_manager import ProviderManager
  15. from fields.dataset_fields import dataset_detail_fields
  16. from fields.tag_fields import tag_fields
  17. from libs.login import current_user
  18. from models.dataset import Dataset, DatasetPermissionEnum
  19. from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
  20. from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
  21. from services.tag_service import TagService
  22. def _validate_name(name):
  23. if not name or len(name) < 1 or len(name) > 40:
  24. raise ValueError("Name must be between 1 to 40 characters.")
  25. return name
  26. def _validate_description_length(description):
  27. if description and len(description) > 400:
  28. raise ValueError("Description cannot exceed 400 characters.")
  29. return description
  30. class DatasetListApi(DatasetApiResource):
  31. """Resource for datasets."""
  32. def get(self, tenant_id):
  33. """Resource for getting datasets."""
  34. page = request.args.get("page", default=1, type=int)
  35. limit = request.args.get("limit", default=20, type=int)
  36. # provider = request.args.get("provider", default="vendor")
  37. search = request.args.get("keyword", default=None, type=str)
  38. tag_ids = request.args.getlist("tag_ids")
  39. include_all = request.args.get("include_all", default="false").lower() == "true"
  40. datasets, total = DatasetService.get_datasets(
  41. page, limit, tenant_id, current_user, search, tag_ids, include_all
  42. )
  43. # check embedding setting
  44. provider_manager = ProviderManager()
  45. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  46. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  47. model_names = []
  48. for embedding_model in embedding_models:
  49. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  50. data = marshal(datasets, dataset_detail_fields)
  51. for item in data:
  52. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  53. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  54. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  55. if item_model in model_names:
  56. item["embedding_available"] = True
  57. else:
  58. item["embedding_available"] = False
  59. else:
  60. item["embedding_available"] = True
  61. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  62. return response, 200
  63. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  64. def post(self, tenant_id):
  65. """Resource for creating datasets."""
  66. parser = reqparse.RequestParser()
  67. parser.add_argument(
  68. "name",
  69. nullable=False,
  70. required=True,
  71. help="type is required. Name must be between 1 to 40 characters.",
  72. type=_validate_name,
  73. )
  74. parser.add_argument(
  75. "description",
  76. type=_validate_description_length,
  77. nullable=True,
  78. required=False,
  79. default="",
  80. )
  81. parser.add_argument(
  82. "indexing_technique",
  83. type=str,
  84. location="json",
  85. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  86. help="Invalid indexing technique.",
  87. )
  88. parser.add_argument(
  89. "permission",
  90. type=str,
  91. location="json",
  92. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  93. help="Invalid permission.",
  94. required=False,
  95. nullable=False,
  96. )
  97. parser.add_argument(
  98. "external_knowledge_api_id",
  99. type=str,
  100. nullable=True,
  101. required=False,
  102. default="_validate_name",
  103. )
  104. parser.add_argument(
  105. "provider",
  106. type=str,
  107. nullable=True,
  108. required=False,
  109. default="vendor",
  110. )
  111. parser.add_argument(
  112. "external_knowledge_id",
  113. type=str,
  114. nullable=True,
  115. required=False,
  116. )
  117. parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
  118. parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  119. parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  120. args = parser.parse_args()
  121. if args.get("embedding_model_provider"):
  122. DatasetService.check_embedding_model_setting(
  123. tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
  124. )
  125. if (
  126. args.get("retrieval_model")
  127. and args.get("retrieval_model").get("reranking_model")
  128. and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  129. ):
  130. DatasetService.check_reranking_model_setting(
  131. tenant_id,
  132. args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  133. args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  134. )
  135. try:
  136. dataset = DatasetService.create_empty_dataset(
  137. tenant_id=tenant_id,
  138. name=args["name"],
  139. description=args["description"],
  140. indexing_technique=args["indexing_technique"],
  141. account=current_user,
  142. permission=args["permission"],
  143. provider=args["provider"],
  144. external_knowledge_api_id=args["external_knowledge_api_id"],
  145. external_knowledge_id=args["external_knowledge_id"],
  146. embedding_model_provider=args["embedding_model_provider"],
  147. embedding_model_name=args["embedding_model"],
  148. retrieval_model=RetrievalModel(**args["retrieval_model"])
  149. if args["retrieval_model"] is not None
  150. else None,
  151. )
  152. except services.errors.dataset.DatasetNameDuplicateError:
  153. raise DatasetNameDuplicateError()
  154. return marshal(dataset, dataset_detail_fields), 200
  155. class DatasetApi(DatasetApiResource):
  156. """Resource for dataset."""
  157. def get(self, _, dataset_id):
  158. dataset_id_str = str(dataset_id)
  159. dataset = DatasetService.get_dataset(dataset_id_str)
  160. if dataset is None:
  161. raise NotFound("Dataset not found.")
  162. try:
  163. DatasetService.check_dataset_permission(dataset, current_user)
  164. except services.errors.account.NoPermissionError as e:
  165. raise Forbidden(str(e))
  166. data = marshal(dataset, dataset_detail_fields)
  167. if data.get("permission") == "partial_members":
  168. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  169. data.update({"partial_member_list": part_users_list})
  170. # check embedding setting
  171. provider_manager = ProviderManager()
  172. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  173. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  174. model_names = []
  175. for embedding_model in embedding_models:
  176. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  177. if data["indexing_technique"] == "high_quality":
  178. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  179. if item_model in model_names:
  180. data["embedding_available"] = True
  181. else:
  182. data["embedding_available"] = False
  183. else:
  184. data["embedding_available"] = True
  185. if data.get("permission") == "partial_members":
  186. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  187. data.update({"partial_member_list": part_users_list})
  188. return data, 200
  189. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  190. def patch(self, _, dataset_id):
  191. dataset_id_str = str(dataset_id)
  192. dataset = DatasetService.get_dataset(dataset_id_str)
  193. if dataset is None:
  194. raise NotFound("Dataset not found.")
  195. parser = reqparse.RequestParser()
  196. parser.add_argument(
  197. "name",
  198. nullable=False,
  199. help="type is required. Name must be between 1 to 40 characters.",
  200. type=_validate_name,
  201. )
  202. parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
  203. parser.add_argument(
  204. "indexing_technique",
  205. type=str,
  206. location="json",
  207. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  208. nullable=True,
  209. help="Invalid indexing technique.",
  210. )
  211. parser.add_argument(
  212. "permission",
  213. type=str,
  214. location="json",
  215. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  216. help="Invalid permission.",
  217. )
  218. parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  219. parser.add_argument(
  220. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  221. )
  222. parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  223. parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  224. parser.add_argument(
  225. "external_retrieval_model",
  226. type=dict,
  227. required=False,
  228. nullable=True,
  229. location="json",
  230. help="Invalid external retrieval model.",
  231. )
  232. parser.add_argument(
  233. "external_knowledge_id",
  234. type=str,
  235. required=False,
  236. nullable=True,
  237. location="json",
  238. help="Invalid external knowledge id.",
  239. )
  240. parser.add_argument(
  241. "external_knowledge_api_id",
  242. type=str,
  243. required=False,
  244. nullable=True,
  245. location="json",
  246. help="Invalid external knowledge api id.",
  247. )
  248. args = parser.parse_args()
  249. data = request.get_json()
  250. # check embedding model setting
  251. if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
  252. DatasetService.check_embedding_model_setting(
  253. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  254. )
  255. if (
  256. data.get("retrieval_model")
  257. and data.get("retrieval_model").get("reranking_model")
  258. and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  259. ):
  260. DatasetService.check_reranking_model_setting(
  261. dataset.tenant_id,
  262. data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  263. data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  264. )
  265. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  266. DatasetPermissionService.check_permission(
  267. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  268. )
  269. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  270. if dataset is None:
  271. raise NotFound("Dataset not found.")
  272. result_data = marshal(dataset, dataset_detail_fields)
  273. tenant_id = current_user.current_tenant_id
  274. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  275. DatasetPermissionService.update_partial_member_list(
  276. tenant_id, dataset_id_str, data.get("partial_member_list")
  277. )
  278. # clear partial member list when permission is only_me or all_team_members
  279. elif (
  280. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  281. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  282. ):
  283. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  284. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  285. result_data.update({"partial_member_list": partial_member_list})
  286. return result_data, 200
  287. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  288. def delete(self, _, dataset_id):
  289. """
  290. Deletes a dataset given its ID.
  291. Args:
  292. _: ignore
  293. dataset_id (UUID): The ID of the dataset to be deleted.
  294. Returns:
  295. dict: A dictionary with a key 'result' and a value 'success'
  296. if the dataset was successfully deleted. Omitted in HTTP response.
  297. int: HTTP status code 204 indicating that the operation was successful.
  298. Raises:
  299. NotFound: If the dataset with the given ID does not exist.
  300. """
  301. dataset_id_str = str(dataset_id)
  302. try:
  303. if DatasetService.delete_dataset(dataset_id_str, current_user):
  304. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  305. return 204
  306. else:
  307. raise NotFound("Dataset not found.")
  308. except services.errors.dataset.DatasetInUseError:
  309. raise DatasetInUseError()
  310. class DocumentStatusApi(DatasetApiResource):
  311. """Resource for batch document status operations."""
  312. def patch(self, tenant_id, dataset_id, action):
  313. """
  314. Batch update document status.
  315. Args:
  316. tenant_id: tenant id
  317. dataset_id: dataset id
  318. action: action to perform (enable, disable, archive, un_archive)
  319. Returns:
  320. dict: A dictionary with a key 'result' and a value 'success'
  321. int: HTTP status code 200 indicating that the operation was successful.
  322. Raises:
  323. NotFound: If the dataset with the given ID does not exist.
  324. Forbidden: If the user does not have permission.
  325. InvalidActionError: If the action is invalid or cannot be performed.
  326. """
  327. dataset_id_str = str(dataset_id)
  328. dataset = DatasetService.get_dataset(dataset_id_str)
  329. if dataset is None:
  330. raise NotFound("Dataset not found.")
  331. # Check user's permission
  332. try:
  333. DatasetService.check_dataset_permission(dataset, current_user)
  334. except services.errors.account.NoPermissionError as e:
  335. raise Forbidden(str(e))
  336. # Check dataset model setting
  337. DatasetService.check_dataset_model_setting(dataset)
  338. # Get document IDs from request body
  339. data = request.get_json()
  340. document_ids = data.get("document_ids", [])
  341. try:
  342. DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
  343. except services.errors.document.DocumentIndexingError as e:
  344. raise InvalidActionError(str(e))
  345. except ValueError as e:
  346. raise InvalidActionError(str(e))
  347. return {"result": "success"}, 200
  348. class DatasetTagsApi(DatasetApiResource):
  349. @validate_dataset_token
  350. @marshal_with(tag_fields)
  351. def get(self, _, dataset_id):
  352. """Get all knowledge type tags."""
  353. tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
  354. return tags, 200
  355. @validate_dataset_token
  356. def post(self, _, dataset_id):
  357. """Add a knowledge type tag."""
  358. if not (current_user.is_editor or current_user.is_dataset_editor):
  359. raise Forbidden()
  360. parser = reqparse.RequestParser()
  361. parser.add_argument(
  362. "name",
  363. nullable=False,
  364. required=True,
  365. help="Name must be between 1 to 50 characters.",
  366. type=DatasetTagsApi._validate_tag_name,
  367. )
  368. args = parser.parse_args()
  369. args["type"] = "knowledge"
  370. tag = TagService.save_tags(args)
  371. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
  372. return response, 200
  373. @validate_dataset_token
  374. def patch(self, _, dataset_id):
  375. if not (current_user.is_editor or current_user.is_dataset_editor):
  376. raise Forbidden()
  377. parser = reqparse.RequestParser()
  378. parser.add_argument(
  379. "name",
  380. nullable=False,
  381. required=True,
  382. help="Name must be between 1 to 50 characters.",
  383. type=DatasetTagsApi._validate_tag_name,
  384. )
  385. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  386. args = parser.parse_args()
  387. args["type"] = "knowledge"
  388. tag = TagService.update_tags(args, args.get("tag_id"))
  389. binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
  390. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
  391. return response, 200
  392. @validate_dataset_token
  393. def delete(self, _, dataset_id):
  394. """Delete a knowledge type tag."""
  395. if not current_user.is_editor:
  396. raise Forbidden()
  397. parser = reqparse.RequestParser()
  398. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  399. args = parser.parse_args()
  400. TagService.delete_tag(args.get("tag_id"))
  401. return 204
  402. @staticmethod
  403. def _validate_tag_name(name):
  404. if not name or len(name) < 1 or len(name) > 50:
  405. raise ValueError("Name must be between 1 to 50 characters.")
  406. return name
  407. class DatasetTagBindingApi(DatasetApiResource):
  408. @validate_dataset_token
  409. def post(self, _, dataset_id):
  410. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  411. if not (current_user.is_editor or current_user.is_dataset_editor):
  412. raise Forbidden()
  413. parser = reqparse.RequestParser()
  414. parser.add_argument(
  415. "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
  416. )
  417. parser.add_argument(
  418. "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
  419. )
  420. args = parser.parse_args()
  421. args["type"] = "knowledge"
  422. TagService.save_tag_binding(args)
  423. return 204
  424. class DatasetTagUnbindingApi(DatasetApiResource):
  425. @validate_dataset_token
  426. def post(self, _, dataset_id):
  427. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  428. if not (current_user.is_editor or current_user.is_dataset_editor):
  429. raise Forbidden()
  430. parser = reqparse.RequestParser()
  431. parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
  432. parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
  433. args = parser.parse_args()
  434. args["type"] = "knowledge"
  435. TagService.delete_tag_binding(args)
  436. return 204
  437. class DatasetTagsBindingStatusApi(DatasetApiResource):
  438. @validate_dataset_token
  439. def get(self, _, *args, **kwargs):
  440. """Get all knowledge type tags."""
  441. dataset_id = kwargs.get("dataset_id")
  442. tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
  443. tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
  444. response = {"data": tags_list, "total": len(tags)}
  445. return response, 200
  446. api.add_resource(DatasetListApi, "/datasets")
  447. api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
  448. api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
  449. api.add_resource(DatasetTagsApi, "/datasets/tags")
  450. api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
  451. api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
  452. api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")