Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

dataset.py 17KB


  1. from flask import request
  2. from flask_restful import marshal, marshal_with, reqparse
  3. from werkzeug.exceptions import Forbidden, NotFound
  4. import services.dataset_service
  5. from controllers.service_api import api
  6. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
  7. from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.plugin.entities.plugin import ModelProviderID
  10. from core.provider_manager import ProviderManager
  11. from fields.dataset_fields import dataset_detail_fields
  12. from fields.tag_fields import tag_fields
  13. from libs.login import current_user
  14. from models.dataset import Dataset, DatasetPermissionEnum
  15. from services.dataset_service import DatasetPermissionService, DatasetService
  16. from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
  17. from services.tag_service import TagService
  18. def _validate_name(name):
  19. if not name or len(name) < 1 or len(name) > 40:
  20. raise ValueError("Name must be between 1 to 40 characters.")
  21. return name
  22. def _validate_description_length(description):
  23. if len(description) > 400:
  24. raise ValueError("Description cannot exceed 400 characters.")
  25. return description
  26. class DatasetListApi(DatasetApiResource):
  27. """Resource for datasets."""
  28. def get(self, tenant_id):
  29. """Resource for getting datasets."""
  30. page = request.args.get("page", default=1, type=int)
  31. limit = request.args.get("limit", default=20, type=int)
  32. # provider = request.args.get("provider", default="vendor")
  33. search = request.args.get("keyword", default=None, type=str)
  34. tag_ids = request.args.getlist("tag_ids")
  35. include_all = request.args.get("include_all", default="false").lower() == "true"
  36. datasets, total = DatasetService.get_datasets(
  37. page, limit, tenant_id, current_user, search, tag_ids, include_all
  38. )
  39. # check embedding setting
  40. provider_manager = ProviderManager()
  41. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  42. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  43. model_names = []
  44. for embedding_model in embedding_models:
  45. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  46. data = marshal(datasets, dataset_detail_fields)
  47. for item in data:
  48. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  49. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  50. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  51. if item_model in model_names:
  52. item["embedding_available"] = True
  53. else:
  54. item["embedding_available"] = False
  55. else:
  56. item["embedding_available"] = True
  57. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  58. return response, 200
  59. def post(self, tenant_id):
  60. """Resource for creating datasets."""
  61. parser = reqparse.RequestParser()
  62. parser.add_argument(
  63. "name",
  64. nullable=False,
  65. required=True,
  66. help="type is required. Name must be between 1 to 40 characters.",
  67. type=_validate_name,
  68. )
  69. parser.add_argument(
  70. "description",
  71. type=str,
  72. nullable=True,
  73. required=False,
  74. default="",
  75. )
  76. parser.add_argument(
  77. "indexing_technique",
  78. type=str,
  79. location="json",
  80. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  81. help="Invalid indexing technique.",
  82. )
  83. parser.add_argument(
  84. "permission",
  85. type=str,
  86. location="json",
  87. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  88. help="Invalid permission.",
  89. required=False,
  90. nullable=False,
  91. )
  92. parser.add_argument(
  93. "external_knowledge_api_id",
  94. type=str,
  95. nullable=True,
  96. required=False,
  97. default="_validate_name",
  98. )
  99. parser.add_argument(
  100. "provider",
  101. type=str,
  102. nullable=True,
  103. required=False,
  104. default="vendor",
  105. )
  106. parser.add_argument(
  107. "external_knowledge_id",
  108. type=str,
  109. nullable=True,
  110. required=False,
  111. )
  112. parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
  113. parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
  114. parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
  115. args = parser.parse_args()
  116. try:
  117. dataset = DatasetService.create_empty_dataset(
  118. tenant_id=tenant_id,
  119. name=args["name"],
  120. description=args["description"],
  121. indexing_technique=args["indexing_technique"],
  122. account=current_user,
  123. permission=args["permission"],
  124. provider=args["provider"],
  125. external_knowledge_api_id=args["external_knowledge_api_id"],
  126. external_knowledge_id=args["external_knowledge_id"],
  127. embedding_model_provider=args["embedding_model_provider"],
  128. embedding_model_name=args["embedding_model"],
  129. retrieval_model=RetrievalModel(**args["retrieval_model"])
  130. if args["retrieval_model"] is not None
  131. else None,
  132. )
  133. except services.errors.dataset.DatasetNameDuplicateError:
  134. raise DatasetNameDuplicateError()
  135. return marshal(dataset, dataset_detail_fields), 200
  136. class DatasetApi(DatasetApiResource):
  137. """Resource for dataset."""
  138. def get(self, _, dataset_id):
  139. dataset_id_str = str(dataset_id)
  140. dataset = DatasetService.get_dataset(dataset_id_str)
  141. if dataset is None:
  142. raise NotFound("Dataset not found.")
  143. try:
  144. DatasetService.check_dataset_permission(dataset, current_user)
  145. except services.errors.account.NoPermissionError as e:
  146. raise Forbidden(str(e))
  147. data = marshal(dataset, dataset_detail_fields)
  148. if data.get("permission") == "partial_members":
  149. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  150. data.update({"partial_member_list": part_users_list})
  151. # check embedding setting
  152. provider_manager = ProviderManager()
  153. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  154. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  155. model_names = []
  156. for embedding_model in embedding_models:
  157. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  158. if data["indexing_technique"] == "high_quality":
  159. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  160. if item_model in model_names:
  161. data["embedding_available"] = True
  162. else:
  163. data["embedding_available"] = False
  164. else:
  165. data["embedding_available"] = True
  166. if data.get("permission") == "partial_members":
  167. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  168. data.update({"partial_member_list": part_users_list})
  169. return data, 200
  170. def patch(self, _, dataset_id):
  171. dataset_id_str = str(dataset_id)
  172. dataset = DatasetService.get_dataset(dataset_id_str)
  173. if dataset is None:
  174. raise NotFound("Dataset not found.")
  175. parser = reqparse.RequestParser()
  176. parser.add_argument(
  177. "name",
  178. nullable=False,
  179. help="type is required. Name must be between 1 to 40 characters.",
  180. type=_validate_name,
  181. )
  182. parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
  183. parser.add_argument(
  184. "indexing_technique",
  185. type=str,
  186. location="json",
  187. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  188. nullable=True,
  189. help="Invalid indexing technique.",
  190. )
  191. parser.add_argument(
  192. "permission",
  193. type=str,
  194. location="json",
  195. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  196. help="Invalid permission.",
  197. )
  198. parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  199. parser.add_argument(
  200. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  201. )
  202. parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  203. parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  204. parser.add_argument(
  205. "external_retrieval_model",
  206. type=dict,
  207. required=False,
  208. nullable=True,
  209. location="json",
  210. help="Invalid external retrieval model.",
  211. )
  212. parser.add_argument(
  213. "external_knowledge_id",
  214. type=str,
  215. required=False,
  216. nullable=True,
  217. location="json",
  218. help="Invalid external knowledge id.",
  219. )
  220. parser.add_argument(
  221. "external_knowledge_api_id",
  222. type=str,
  223. required=False,
  224. nullable=True,
  225. location="json",
  226. help="Invalid external knowledge api id.",
  227. )
  228. args = parser.parse_args()
  229. data = request.get_json()
  230. # check embedding model setting
  231. if data.get("indexing_technique") == "high_quality":
  232. DatasetService.check_embedding_model_setting(
  233. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  234. )
  235. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  236. DatasetPermissionService.check_permission(
  237. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  238. )
  239. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  240. if dataset is None:
  241. raise NotFound("Dataset not found.")
  242. result_data = marshal(dataset, dataset_detail_fields)
  243. tenant_id = current_user.current_tenant_id
  244. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  245. DatasetPermissionService.update_partial_member_list(
  246. tenant_id, dataset_id_str, data.get("partial_member_list")
  247. )
  248. # clear partial member list when permission is only_me or all_team_members
  249. elif (
  250. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  251. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  252. ):
  253. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  254. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  255. result_data.update({"partial_member_list": partial_member_list})
  256. return result_data, 200
  257. def delete(self, _, dataset_id):
  258. """
  259. Deletes a dataset given its ID.
  260. Args:
  261. _: ignore
  262. dataset_id (UUID): The ID of the dataset to be deleted.
  263. Returns:
  264. dict: A dictionary with a key 'result' and a value 'success'
  265. if the dataset was successfully deleted. Omitted in HTTP response.
  266. int: HTTP status code 204 indicating that the operation was successful.
  267. Raises:
  268. NotFound: If the dataset with the given ID does not exist.
  269. """
  270. dataset_id_str = str(dataset_id)
  271. try:
  272. if DatasetService.delete_dataset(dataset_id_str, current_user):
  273. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  274. return 204
  275. else:
  276. raise NotFound("Dataset not found.")
  277. except services.errors.dataset.DatasetInUseError:
  278. raise DatasetInUseError()
  279. class DatasetTagsApi(DatasetApiResource):
  280. @validate_dataset_token
  281. @marshal_with(tag_fields)
  282. def get(self, _, dataset_id):
  283. """Get all knowledge type tags."""
  284. tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
  285. return tags, 200
  286. @validate_dataset_token
  287. def post(self, _, dataset_id):
  288. """Add a knowledge type tag."""
  289. if not (current_user.is_editor or current_user.is_dataset_editor):
  290. raise Forbidden()
  291. parser = reqparse.RequestParser()
  292. parser.add_argument(
  293. "name",
  294. nullable=False,
  295. required=True,
  296. help="Name must be between 1 to 50 characters.",
  297. type=DatasetTagsApi._validate_tag_name,
  298. )
  299. args = parser.parse_args()
  300. args["type"] = "knowledge"
  301. tag = TagService.save_tags(args)
  302. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
  303. return response, 200
  304. @validate_dataset_token
  305. def patch(self, _, dataset_id):
  306. if not (current_user.is_editor or current_user.is_dataset_editor):
  307. raise Forbidden()
  308. parser = reqparse.RequestParser()
  309. parser.add_argument(
  310. "name",
  311. nullable=False,
  312. required=True,
  313. help="Name must be between 1 to 50 characters.",
  314. type=DatasetTagsApi._validate_tag_name,
  315. )
  316. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  317. args = parser.parse_args()
  318. args["type"] = "knowledge"
  319. tag = TagService.update_tags(args, args.get("tag_id"))
  320. binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
  321. response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
  322. return response, 200
  323. @validate_dataset_token
  324. def delete(self, _, dataset_id):
  325. """Delete a knowledge type tag."""
  326. if not current_user.is_editor:
  327. raise Forbidden()
  328. parser = reqparse.RequestParser()
  329. parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
  330. args = parser.parse_args()
  331. TagService.delete_tag(args.get("tag_id"))
  332. return 204
  333. @staticmethod
  334. def _validate_tag_name(name):
  335. if not name or len(name) < 1 or len(name) > 50:
  336. raise ValueError("Name must be between 1 to 50 characters.")
  337. return name
  338. class DatasetTagBindingApi(DatasetApiResource):
  339. @validate_dataset_token
  340. def post(self, _, dataset_id):
  341. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  342. if not (current_user.is_editor or current_user.is_dataset_editor):
  343. raise Forbidden()
  344. parser = reqparse.RequestParser()
  345. parser.add_argument(
  346. "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
  347. )
  348. parser.add_argument(
  349. "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
  350. )
  351. args = parser.parse_args()
  352. args["type"] = "knowledge"
  353. TagService.save_tag_binding(args)
  354. return 204
  355. class DatasetTagUnbindingApi(DatasetApiResource):
  356. @validate_dataset_token
  357. def post(self, _, dataset_id):
  358. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  359. if not (current_user.is_editor or current_user.is_dataset_editor):
  360. raise Forbidden()
  361. parser = reqparse.RequestParser()
  362. parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
  363. parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
  364. args = parser.parse_args()
  365. args["type"] = "knowledge"
  366. TagService.delete_tag_binding(args)
  367. return 204
  368. class DatasetTagsBindingStatusApi(DatasetApiResource):
  369. @validate_dataset_token
  370. def get(self, _, *args, **kwargs):
  371. """Get all knowledge type tags."""
  372. dataset_id = kwargs.get("dataset_id")
  373. tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
  374. tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
  375. response = {"data": tags_list, "total": len(tags)}
  376. return response, 200
  377. api.add_resource(DatasetListApi, "/datasets")
  378. api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
  379. api.add_resource(DatasetTagsApi, "/datasets/tags")
  380. api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
  381. api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
  382. api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")