
dataset.py 13KB

from flask import request
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services.dataset_service
import services.errors.account
import services.errors.dataset
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetListApi(DatasetApiResource):
    """Resource for datasets."""

    def get(self, tenant_id):
        """Resource for getting datasets."""
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        # provider = request.args.get("provider", default="vendor")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"

        datasets, total = DatasetService.get_datasets(
            page, limit, tenant_id, current_user, search, tag_ids, include_all
        )

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                if item_model in model_names:
                    item["embedding_available"] = True
                else:
                    item["embedding_available"] = False
            else:
                item["embedding_available"] = True

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200
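
    # Illustrative example: the query parameters read above map to a request such as
    #   GET /datasets?page=1&limit=20&keyword=faq&tag_ids=<tag_id>&include_all=false
    # "faq" and <tag_id> are placeholder values; the host and the authentication expected by
    # DatasetApiResource are deployment-specific and not defined in this file.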

    def post(self, tenant_id):
        """Resource for creating datasets."""
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
            required=False,
            nullable=False,
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
            default=None,  # optional; no external knowledge API by default
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=args["permission"],
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
                embedding_model_provider=args["embedding_model_provider"],
                embedding_model_name=args["embedding_model"],
                retrieval_model=RetrievalModel(**args["retrieval_model"])
                if args["retrieval_model"] is not None
                else None,
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 200
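
# Illustrative example: a minimal JSON body for POST /datasets, using only fields declared in
# DatasetListApi.post above; the values are placeholders.
#   {
#       "name": "product-faq",
#       "description": "FAQ knowledge base",
#       "indexing_technique": "high_quality",
#       "permission": "only_me",
#       "provider": "vendor"
#   }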


class DatasetApi(DatasetApiResource):
    """Resource for dataset."""

    def get(self, _, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

        model_names = []
        for embedding_model in embedding_models:
            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

        if data["indexing_technique"] == "high_quality":
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            if item_model in model_names:
                data["embedding_available"] = True
            else:
                data["embedding_available"] = False
        else:
            data["embedding_available"] = True

        if data.get("permission") == "partial_members":
            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
            data.update({"partial_member_list": part_users_list})

        return data, 200

    def patch(self, _, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
        )
        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
        parser.add_argument(
            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
        )
        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid partial member list.")
        parser.add_argument(
            "external_retrieval_model",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external retrieval model.",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge id.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge api id.",
        )
        args = parser.parse_args()
        data = request.get_json()

        # check embedding model setting
        if data.get("indexing_technique") == "high_quality":
            DatasetService.check_embedding_model_setting(
                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
            )

        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user, dataset, data.get("permission"), data.get("partial_member_list")
        )

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = marshal(dataset, dataset_detail_fields)
        tenant_id = current_user.current_tenant_id

        if data.get("partial_member_list") and data.get("permission") == "partial_members":
            DatasetPermissionService.update_partial_member_list(
                tenant_id, dataset_id_str, data.get("partial_member_list")
            )
        # clear partial member list when permission is only_me or all_team_members
        elif (
            data.get("permission") == DatasetPermissionEnum.ONLY_ME
            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
        ):
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        result_data.update({"partial_member_list": partial_member_list})

        return result_data, 200
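
    # Note: patch() reads the request body twice. reqparse produces `args`, which is passed to
    # DatasetService.update_dataset, while request.get_json() produces `data`, which drives the
    # embedding-model and permission handling. A partial update body such as
    # {"name": "renamed-dataset", "permission": "only_me"} therefore flows through both paths.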

    def delete(self, _, dataset_id):
        """
        Deletes a dataset given its ID.

        Args:
            _: ignore
            dataset_id (UUID): The ID of the dataset to be deleted.

        Returns:
            dict: A dictionary with a key 'result' and a value 'success'
                if the dataset was successfully deleted. Omitted in HTTP response.
            int: HTTP status code 204 indicating that the operation was successful.

        Raises:
            NotFound: If the dataset with the given ID does not exist.
        """
        dataset_id_str = str(dataset_id)

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                # return the body described in the docstring together with a 204 status
                return {"result": "success"}, 204
            else:
                raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()


api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
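
# Illustrative client sketch for the routes registered above, using the `requests` library.
# Assumptions: the base URL, the Bearer API key header, and the "id" field read from the
# create response are placeholders, not values defined in this file.
#
#   import requests
#
#   BASE_URL = "https://example.com/v1"                       # placeholder host
#   HEADERS = {"Authorization": "Bearer <dataset-api-key>"}   # placeholder credential
#
#   # Create an empty dataset (POST /datasets)
#   resp = requests.post(f"{BASE_URL}/datasets", headers=HEADERS, json={"name": "product-faq"})
#   dataset_id = resp.json()["id"]
#
#   # List datasets (GET /datasets)
#   requests.get(f"{BASE_URL}/datasets", headers=HEADERS, params={"page": 1, "limit": 20})
#
#   # Delete the dataset (DELETE /datasets/<uuid:dataset_id>)
#   requests.delete(f"{BASE_URL}/datasets/{dataset_id}", headers=HEADERS)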