You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. from flask import request
  2. from flask_restful import marshal, reqparse # type: ignore
  3. from werkzeug.exceptions import Forbidden, NotFound
  4. import services.dataset_service
  5. from controllers.service_api import api
  6. from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
  7. from controllers.service_api.wraps import DatasetApiResource
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.plugin.entities.plugin import ModelProviderID
  10. from core.provider_manager import ProviderManager
  11. from fields.dataset_fields import dataset_detail_fields
  12. from libs.login import current_user
  13. from models.dataset import Dataset, DatasetPermissionEnum
  14. from services.dataset_service import DatasetPermissionService, DatasetService
  15. def _validate_name(name):
  16. if not name or len(name) < 1 or len(name) > 40:
  17. raise ValueError("Name must be between 1 to 40 characters.")
  18. return name
  19. def _validate_description_length(description):
  20. if len(description) > 400:
  21. raise ValueError("Description cannot exceed 400 characters.")
  22. return description
  23. class DatasetListApi(DatasetApiResource):
  24. """Resource for datasets."""
  25. def get(self, tenant_id):
  26. """Resource for getting datasets."""
  27. page = request.args.get("page", default=1, type=int)
  28. limit = request.args.get("limit", default=20, type=int)
  29. # provider = request.args.get("provider", default="vendor")
  30. search = request.args.get("keyword", default=None, type=str)
  31. tag_ids = request.args.getlist("tag_ids")
  32. include_all = request.args.get("include_all", default="false").lower() == "true"
  33. datasets, total = DatasetService.get_datasets(
  34. page, limit, tenant_id, current_user, search, tag_ids, include_all
  35. )
  36. # check embedding setting
  37. provider_manager = ProviderManager()
  38. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  39. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  40. model_names = []
  41. for embedding_model in embedding_models:
  42. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  43. data = marshal(datasets, dataset_detail_fields)
  44. for item in data:
  45. if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
  46. item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
  47. item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
  48. if item_model in model_names:
  49. item["embedding_available"] = True
  50. else:
  51. item["embedding_available"] = False
  52. else:
  53. item["embedding_available"] = True
  54. response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
  55. return response, 200
  56. def post(self, tenant_id):
  57. """Resource for creating datasets."""
  58. parser = reqparse.RequestParser()
  59. parser.add_argument(
  60. "name",
  61. nullable=False,
  62. required=True,
  63. help="type is required. Name must be between 1 to 40 characters.",
  64. type=_validate_name,
  65. )
  66. parser.add_argument(
  67. "description",
  68. type=str,
  69. nullable=True,
  70. required=False,
  71. default="",
  72. )
  73. parser.add_argument(
  74. "indexing_technique",
  75. type=str,
  76. location="json",
  77. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  78. help="Invalid indexing technique.",
  79. )
  80. parser.add_argument(
  81. "permission",
  82. type=str,
  83. location="json",
  84. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  85. help="Invalid permission.",
  86. required=False,
  87. nullable=False,
  88. )
  89. parser.add_argument(
  90. "external_knowledge_api_id",
  91. type=str,
  92. nullable=True,
  93. required=False,
  94. default="_validate_name",
  95. )
  96. parser.add_argument(
  97. "provider",
  98. type=str,
  99. nullable=True,
  100. required=False,
  101. default="vendor",
  102. )
  103. parser.add_argument(
  104. "external_knowledge_id",
  105. type=str,
  106. nullable=True,
  107. required=False,
  108. )
  109. args = parser.parse_args()
  110. try:
  111. dataset = DatasetService.create_empty_dataset(
  112. tenant_id=tenant_id,
  113. name=args["name"],
  114. description=args["description"],
  115. indexing_technique=args["indexing_technique"],
  116. account=current_user,
  117. permission=args["permission"],
  118. provider=args["provider"],
  119. external_knowledge_api_id=args["external_knowledge_api_id"],
  120. external_knowledge_id=args["external_knowledge_id"],
  121. )
  122. except services.errors.dataset.DatasetNameDuplicateError:
  123. raise DatasetNameDuplicateError()
  124. return marshal(dataset, dataset_detail_fields), 200
  125. class DatasetApi(DatasetApiResource):
  126. """Resource for dataset."""
  127. def get(self, _, dataset_id):
  128. dataset_id_str = str(dataset_id)
  129. dataset = DatasetService.get_dataset(dataset_id_str)
  130. if dataset is None:
  131. raise NotFound("Dataset not found.")
  132. try:
  133. DatasetService.check_dataset_permission(dataset, current_user)
  134. except services.errors.account.NoPermissionError as e:
  135. raise Forbidden(str(e))
  136. data = marshal(dataset, dataset_detail_fields)
  137. if data.get("permission") == "partial_members":
  138. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  139. data.update({"partial_member_list": part_users_list})
  140. # check embedding setting
  141. provider_manager = ProviderManager()
  142. configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
  143. embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
  144. model_names = []
  145. for embedding_model in embedding_models:
  146. model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
  147. if data["indexing_technique"] == "high_quality":
  148. item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
  149. if item_model in model_names:
  150. data["embedding_available"] = True
  151. else:
  152. data["embedding_available"] = False
  153. else:
  154. data["embedding_available"] = True
  155. if data.get("permission") == "partial_members":
  156. part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  157. data.update({"partial_member_list": part_users_list})
  158. return data, 200
  159. def patch(self, _, dataset_id):
  160. dataset_id_str = str(dataset_id)
  161. dataset = DatasetService.get_dataset(dataset_id_str)
  162. if dataset is None:
  163. raise NotFound("Dataset not found.")
  164. parser = reqparse.RequestParser()
  165. parser.add_argument(
  166. "name",
  167. nullable=False,
  168. help="type is required. Name must be between 1 to 40 characters.",
  169. type=_validate_name,
  170. )
  171. parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
  172. parser.add_argument(
  173. "indexing_technique",
  174. type=str,
  175. location="json",
  176. choices=Dataset.INDEXING_TECHNIQUE_LIST,
  177. nullable=True,
  178. help="Invalid indexing technique.",
  179. )
  180. parser.add_argument(
  181. "permission",
  182. type=str,
  183. location="json",
  184. choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
  185. help="Invalid permission.",
  186. )
  187. parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
  188. parser.add_argument(
  189. "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
  190. )
  191. parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
  192. parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
  193. parser.add_argument(
  194. "external_retrieval_model",
  195. type=dict,
  196. required=False,
  197. nullable=True,
  198. location="json",
  199. help="Invalid external retrieval model.",
  200. )
  201. parser.add_argument(
  202. "external_knowledge_id",
  203. type=str,
  204. required=False,
  205. nullable=True,
  206. location="json",
  207. help="Invalid external knowledge id.",
  208. )
  209. parser.add_argument(
  210. "external_knowledge_api_id",
  211. type=str,
  212. required=False,
  213. nullable=True,
  214. location="json",
  215. help="Invalid external knowledge api id.",
  216. )
  217. args = parser.parse_args()
  218. data = request.get_json()
  219. # check embedding model setting
  220. if data.get("indexing_technique") == "high_quality":
  221. DatasetService.check_embedding_model_setting(
  222. dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
  223. )
  224. # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
  225. DatasetPermissionService.check_permission(
  226. current_user, dataset, data.get("permission"), data.get("partial_member_list")
  227. )
  228. dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
  229. if dataset is None:
  230. raise NotFound("Dataset not found.")
  231. result_data = marshal(dataset, dataset_detail_fields)
  232. tenant_id = current_user.current_tenant_id
  233. if data.get("partial_member_list") and data.get("permission") == "partial_members":
  234. DatasetPermissionService.update_partial_member_list(
  235. tenant_id, dataset_id_str, data.get("partial_member_list")
  236. )
  237. # clear partial member list when permission is only_me or all_team_members
  238. elif (
  239. data.get("permission") == DatasetPermissionEnum.ONLY_ME
  240. or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
  241. ):
  242. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  243. partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
  244. result_data.update({"partial_member_list": partial_member_list})
  245. return result_data, 200
  246. def delete(self, _, dataset_id):
  247. """
  248. Deletes a dataset given its ID.
  249. Args:
  250. _: ignore
  251. dataset_id (UUID): The ID of the dataset to be deleted.
  252. Returns:
  253. dict: A dictionary with a key 'result' and a value 'success'
  254. if the dataset was successfully deleted. Omitted in HTTP response.
  255. int: HTTP status code 204 indicating that the operation was successful.
  256. Raises:
  257. NotFound: If the dataset with the given ID does not exist.
  258. """
  259. dataset_id_str = str(dataset_id)
  260. try:
  261. if DatasetService.delete_dataset(dataset_id_str, current_user):
  262. DatasetPermissionService.clear_partial_member_list(dataset_id_str)
  263. return {"result": "success"}, 204
  264. else:
  265. raise NotFound("Dataset not found.")
  266. except services.errors.dataset.DatasetInUseError:
  267. raise DatasetInUseError()
  268. api.add_resource(DatasetListApi, "/datasets")
  269. api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")