Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

dataset.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. from typing import Literal
  2. from flask import request
  3. from flask_restful import marshal, marshal_with, reqparse
  4. from werkzeug.exceptions import Forbidden, NotFound
  5. import services.dataset_service
  6. from controllers.service_api import api
  7. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
  8. from controllers.service_api.wraps import (
  9. DatasetApiResource,
  10. cloud_edition_billing_rate_limit_check,
  11. validate_dataset_token,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.plugin.entities.plugin import ModelProviderID
  15. from core.provider_manager import ProviderManager
  16. from fields.dataset_fields import dataset_detail_fields
  17. from fields.tag_fields import tag_fields
  18. from libs.login import current_user
  19. from models.dataset import Dataset, DatasetPermissionEnum
  20. from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
  21. from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
  22. from services.tag_service import TagService
  23. def _validate_name(name):
  24. if not name or len(name) < 1 or len(name) > 40:
  25. raise ValueError("Name must be between 1 to 40 characters.")
  26. return name
  27. def _validate_description_length(description):
  28. if description and len(description) > 400:
  29. raise ValueError("Description cannot exceed 400 characters.")
  30. return description
  31. class DatasetListApi(DatasetApiResource):
  32. """Resource for datasets."""
  33. def get(self, tenant_id):
  34. """Resource for getting datasets."""
  35. page = request.args.get("page", default=1, type=int)
  36. limit = request.args.get("limit", default=20, type=int)
  37. # provider = request.args.get("provider", default="vendor")
  38. search = request.args.get("keyword", default=None, type=str)
  39. tag_ids = request.args.getlist("tag_ids")
  40. include_all = request.args.get("include_all", default="false").lower() == "true"
  41. datasets, total = DatasetService.get_datasets(
  42. page, limit, tenant_id, current_user, search, tag_ids, include_all
  43. )
  44. # check embedding setting
  45. provider_manager = ProviderManager()
  46. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  47. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  48. model_names = []
  49. for embedding_model in embedding_models:
  50. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  51. data = marshal(datasets, dataset_detail_fields)
  52. for item in data:
  53. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  54. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  55. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  56. if item_model in model_names:
  57. item["embedding_available"] = True
  58. else:
  59. item["embedding_available"] = False
  60. else:
  61. item["embedding_available"] = True
  62. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  63. return response, 200
  64. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  65. def post(self, tenant_id):
  66. """Resource for creating datasets."""
  67. parser = reqparse.RequestParser()
  68. parser.add_argument(
  69. "name",
  70. nullable=False,
  71. required=True,
  72. help="type is required. Name must be between 1 to 40 characters.",
  73. type=_validate_name,
  74. )
  75. parser.add_argument(
  76. "description",
  77. type=_validate_description_length,
  78. nullable=True,
  79. required=False,
  80. default="",
  81. )
  82. parser.add_argument(
  83. "indexing_technique",
  84. type=str,
  85. location="json",
  86. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  87. help="Invalid indexing technique.",
  88. )
  89. parser.add_argument(
  90. "permission",
  91. type=str,
  92. location="json",
  93. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  94. help="Invalid permission.",
  95. required=False,
  96. nullable=False,
  97. )
  98. parser.add_argument(
  99. "external_knowledge_api_id",
  100. type=str,
  101. nullable=True,
  102. required=False,
  103. default="_validate_name",
  104. )
  105. parser.add_argument(
  106. "provider",
  107. type=str,
  108. nullable=True,
  109. required=False,
  110. default="vendor",
  111. )
  112. parser.add_argument(
  113. "external_knowledge_id",
  114. type=str,
  115. nullable=True,
  116. required=False,
  117. )
  118. parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
  119. parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  120. parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  121. args = parser.parse_args()
  122. if args.get("embedding_model_provider"):
  123. DatasetService.check_embedding_model_setting(
  124. tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
  125. )
  126. if (
  127. args.get("retrieval_model")
  128. and args.get("retrieval_model").get("reranking_model")
  129. and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  130. ):
  131. DatasetService.check_reranking_model_setting(
  132. tenant_id,
  133. args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  134. args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  135. )
  136. try:
  137. dataset = DatasetService.create_empty_dataset(
  138. tenant_id=tenant_id,
  139. name=args["name"],
  140. description=args["description"],
  141. indexing_technique=args["indexing_technique"],
  142. account=current_user,
  143. permission=args["permission"],
  144. provider=args["provider"],
  145. external_knowledge_api_id=args["external_knowledge_api_id"],
  146. external_knowledge_id=args["external_knowledge_id"],
  147. embedding_model_provider=args["embedding_model_provider"],
  148. embedding_model_name=args["embedding_model"],
  149. retrieval_model=RetrievalModel(**args["retrieval_model"])
  150. if args["retrieval_model"] is not None
  151. else None,
  152. )
  153. except services.errors.dataset.DatasetNameDuplicateError:
  154. raise DatasetNameDuplicateError()
  155. return marshal(dataset, dataset_detail_fields), 200
  156. class DatasetApi(DatasetApiResource):
  157. """Resource for dataset."""
  158. def get(self, _, dataset_id):
  159. dataset_id_str = str(dataset_id)
  160. dataset = DatasetService.get_dataset(dataset_id_str)
  161. if dataset is None:
  162. raise NotFound("Dataset not found.")
  163. try:
  164. DatasetService.check_dataset_permission(dataset, current_user)
  165. except services.errors.account.NoPermissionError as e:
  166. raise Forbidden(str(e))
  167. data = marshal(dataset, dataset_detail_fields)
  168. if data.get("permission") == "partial_members":
  169. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  170. data.update({"partial_member_list": part_users_list})
  171. # check embedding setting
  172. provider_manager = ProviderManager()
  173. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  174. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  175. model_names = []
  176. for embedding_model in embedding_models:
  177. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  178. if data["indexing_technique"] == "high_quality":
  179. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  180. if item_model in model_names:
  181. data["embedding_available"] = True
  182. else:
  183. data["embedding_available"] = False
  184. else:
  185. data["embedding_available"] = True
  186. if data.get("permission") == "partial_members":
  187. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  188. data.update({"partial_member_list": part_users_list})
  189. return data, 200
  190. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  191. def patch(self, _, dataset_id):
  192. dataset_id_str = str(dataset_id)
  193. dataset = DatasetService.get_dataset(dataset_id_str)
  194. if dataset is None:
  195. raise NotFound("Dataset not found.")
  196. parser = reqparse.RequestParser()
  197. parser.add_argument(
  198. "name",
  199. nullable=False,
  200. help="type is required. Name must be between 1 to 40 characters.",
  201. type=_validate_name,
  202. )
  203. parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
  204. parser.add_argument(
  205. "indexing_technique",
  206. type=str,
  207. location="json",
  208. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  209. nullable=True,
  210. help="Invalid indexing technique.",
  211. )
  212. parser.add_argument(
  213. "permission",
  214. type=str,
  215. location="json",
  216. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  217. help="Invalid permission.",
  218. )
  219. parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  220. parser.add_argument(
  221. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  222. )
  223. parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  224. parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  225. parser.add_argument(
  226. "external_retrieval_model",
  227. type=dict,
  228. required=False,
  229. nullable=True,
  230. location="json",
  231. help="Invalid external retrieval model.",
  232. )
  233. parser.add_argument(
  234. "external_knowledge_id",
  235. type=str,
  236. required=False,
  237. nullable=True,
  238. location="json",
  239. help="Invalid external knowledge id.",
  240. )
  241. parser.add_argument(
  242. "external_knowledge_api_id",
  243. type=str,
  244. required=False,
  245. nullable=True,
  246. location="json",
  247. help="Invalid external knowledge api id.",
  248. )
  249. args = parser.parse_args()
  250. data = request.get_json()
  251. # check embedding model setting
  252. if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
  253. DatasetService.check_embedding_model_setting(
  254. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  255. )
  256. if (
  257. data.get("retrieval_model")
  258. and data.get("retrieval_model").get("reranking_model")
  259. and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
  260. ):
  261. DatasetService.check_reranking_model_setting(
  262. dataset.tenant_id,
  263. data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
  264. data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
  265. )
  266. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  267. DatasetPermissionService.check_permission(
  268. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  269. )
  270. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  271. if dataset is None:
  272. raise NotFound("Dataset not found.")
  273. result_data = marshal(dataset, dataset_detail_fields)
  274. tenant_id = current_user.current_tenant_id
  275. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  276. DatasetPermissionService.update_partial_member_list(
  277. tenant_id, dataset_id_str, data.get("partial_member_list")
  278. )
  279. # clear partial member list when permission is only_me or all_team_members
  280. elif (
  281. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  282. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  283. ):
  284. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  285. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  286. result_data.update({"partial_member_list": partial_member_list})
  287. return result_data, 200
  288. @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
  289. def delete(self, _, dataset_id):
  290. """
  291. Deletes a dataset given its ID.
  292. Args:
  293. _: ignore
  294. dataset_id (UUID): The ID of the dataset to be deleted.
  295. Returns:
  296. dict: A dictionary with a key 'result' and a value 'success'
  297. if the dataset was successfully deleted. Omitted in HTTP response.
  298. int: HTTP status code 204 indicating that the operation was successful.
  299. Raises:
  300. NotFound: If the dataset with the given ID does not exist.
  301. """
  302. dataset_id_str = str(dataset_id)
  303. try:
  304. if DatasetService.delete_dataset(dataset_id_str, current_user):
  305. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  306. return 204
  307. else:
  308. raise NotFound("Dataset not found.")
  309. except services.errors.dataset.DatasetInUseError:
  310. raise DatasetInUseError()
  311. class DocumentStatusApi(DatasetApiResource):
  312. """Resource for batch document status operations."""
  313. def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
  314. """
  315. Batch update document status.
  316. Args:
  317. tenant_id: tenant id
  318. dataset_id: dataset id
  319. action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
  320. Returns:
  321. dict: A dictionary with a key 'result' and a value 'success'
  322. int: HTTP status code 200 indicating that the operation was successful.
  323. Raises:
  324. NotFound: If the dataset with the given ID does not exist.
  325. Forbidden: If the user does not have permission.
  326. InvalidActionError: If the action is invalid or cannot be performed.
  327. """
  328. dataset_id_str = str(dataset_id)
  329. dataset = DatasetService.get_dataset(dataset_id_str)
  330. if dataset is None:
  331. raise NotFound("Dataset not found.")
  332. # Check user's permission
  333. try:
  334. DatasetService.check_dataset_permission(dataset, current_user)
  335. except services.errors.account.NoPermissionError as e:
  336. raise Forbidden(str(e))
  337. # Check dataset model setting
  338. DatasetService.check_dataset_model_setting(dataset)
  339. # Get document IDs from request body
  340. data = request.get_json()
  341. document_ids = data.get("document_ids", [])
  342. try:
  343. DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
  344. except services.errors.document.DocumentIndexingError as e:
  345. raise InvalidActionError(str(e))
  346. except ValueError as e:
  347. raise InvalidActionError(str(e))
  348. return {"result": "success"}, 200
  349. class DatasetTagsApi(DatasetApiResource):
  350. @validate_dataset_token
  351. @marshal_with(tag_fields)
  352. def get(self, _, dataset_id):
  353. """Get all knowledge type tags."""
  354. tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
  355. return tags, 200
  356. @validate_dataset_token
  357. def post(self, _, dataset_id):
  358. """Add a knowledge type tag."""
  359. if not (current_user.is_editor or current_user.is_dataset_editor):
  360. raise Forbidden()
  361. parser = reqparse.RequestParser()
  362. parser.add_argument(
  363. "name",
  364. nullable=False,
  365. required=True,
  366. help="Name must be between 1 to 50 characters.",
  367. type=DatasetTagsApi._validate_tag_name,
  368. )
  369. args = parser.parse_args()
  370. args["type"] = "knowledge"
  371. tag = TagService.save_tags(args)
  372. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
  373. return response, 200
  374. @validate_dataset_token
  375. def patch(self, _, dataset_id):
  376. if not (current_user.is_editor or current_user.is_dataset_editor):
  377. raise Forbidden()
  378. parser = reqparse.RequestParser()
  379. parser.add_argument(
  380. "name",
  381. nullable=False,
  382. required=True,
  383. help="Name must be between 1 to 50 characters.",
  384. type=DatasetTagsApi._validate_tag_name,
  385. )
  386. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  387. args = parser.parse_args()
  388. args["type"] = "knowledge"
  389. tag = TagService.update_tags(args, args.get("tag_id"))
  390. binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
  391. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
  392. return response, 200
  393. @validate_dataset_token
  394. def delete(self, _, dataset_id):
  395. """Delete a knowledge type tag."""
  396. if not current_user.is_editor:
  397. raise Forbidden()
  398. parser = reqparse.RequestParser()
  399. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  400. args = parser.parse_args()
  401. TagService.delete_tag(args.get("tag_id"))
  402. return 204
  403. @staticmethod
  404. def _validate_tag_name(name):
  405. if not name or len(name) < 1 or len(name) > 50:
  406. raise ValueError("Name must be between 1 to 50 characters.")
  407. return name
  408. class DatasetTagBindingApi(DatasetApiResource):
  409. @validate_dataset_token
  410. def post(self, _, dataset_id):
  411. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  412. if not (current_user.is_editor or current_user.is_dataset_editor):
  413. raise Forbidden()
  414. parser = reqparse.RequestParser()
  415. parser.add_argument(
  416. "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
  417. )
  418. parser.add_argument(
  419. "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
  420. )
  421. args = parser.parse_args()
  422. args["type"] = "knowledge"
  423. TagService.save_tag_binding(args)
  424. return 204
  425. class DatasetTagUnbindingApi(DatasetApiResource):
  426. @validate_dataset_token
  427. def post(self, _, dataset_id):
  428. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  429. if not (current_user.is_editor or current_user.is_dataset_editor):
  430. raise Forbidden()
  431. parser = reqparse.RequestParser()
  432. parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
  433. parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
  434. args = parser.parse_args()
  435. args["type"] = "knowledge"
  436. TagService.delete_tag_binding(args)
  437. return 204
  438. class DatasetTagsBindingStatusApi(DatasetApiResource):
  439. @validate_dataset_token
  440. def get(self, _, *args, **kwargs):
  441. """Get all knowledge type tags."""
  442. dataset_id = kwargs.get("dataset_id")
  443. tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
  444. tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
  445. response = {"data": tags_list, "total": len(tags)}
  446. return response, 200
  447. api.add_resource(DatasetListApi, "/datasets")
  448. api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
  449. api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
  450. api.add_resource(DatasetTagsApi, "/datasets/tags")
  451. api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
  452. api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
  453. api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")