dataset.py
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging

from flask import request
from peewee import OperationalError

from api import settings
from api.db import FileSource, StatusEnum
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import (
    check_duplicate_ids,
    deep_merge,
    get_error_argument_result,
    get_error_data_result,
    get_parser_config,
    get_result,
    token_required,
    verify_embedding_availability,
)
from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
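
# NOTE: `manager` is not imported in this module. RAGFlow's app loader is
# expected to inject a Flask blueprint named `manager` into each route module
# at startup, which is why the route decorators below carry `# noqa: F821`.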


@manager.route("/datasets", methods=["POST"])  # noqa: F821
@token_required
def create(tenant_id):
    """
    Create a new dataset.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset creation parameters.
        required: true
        schema:
          type: object
          required:
            - name
          properties:
            name:
              type: string
              description: Name of the dataset.
            avatar:
              type: string
              description: Base64 encoding of the avatar.
            description:
              type: string
              description: Description of the dataset.
            embedding_model:
              type: string
              description: Embedding model name.
            permission:
              type: string
              enum: ['me', 'team']
              description: Dataset permission.
            chunk_method:
              type: string
              enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
                     "picture", "presentation", "qa", "table", "tag"]
              description: Chunking method.
            pagerank:
              type: integer
              description: Set page rank.
            parser_config:
              type: object
              description: Parser configuration.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
          properties:
            data:
              type: object
    """
    # Field name transformations during model dump:
    # | Original        | Dump Output |
    # |-----------------|-------------|
    # | embedding_model | embd_id     |
    # | chunk_method    | parser_id   |
    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return get_error_argument_result(err)

    try:
        if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
            return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
    req["id"] = get_uuid()
    req["tenant_id"] = tenant_id
    req["created_by"] = tenant_id

    try:
        ok, t = TenantService.get_by_id(tenant_id)
        if not ok:
            return get_error_data_result(message="Tenant not found")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    if not req.get("embd_id"):
        req["embd_id"] = t.embd_id
    else:
        ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
        if not ok:
            return err

    try:
        if not KnowledgebaseService.save(**req):
            return get_error_data_result(message="Create dataset error. (Database error)")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    try:
        ok, k = KnowledgebaseService.get_by_id(req["id"])
        if not ok:
            return get_error_data_result(message="Dataset creation failed")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    response_data = {}
    key_mapping = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }
    for key, value in k.to_dict().items():
        new_key = key_mapping.get(key, key)
        response_data[new_key] = value
    return get_result(data=response_data)
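
# A minimal usage sketch for the create endpoint above (illustrative only:
# the base URL, port, and API key are deployment assumptions, not defined here):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       json={"name": "my_dataset", "chunk_method": "naive"},
#   )
#   print(resp.json())  # {"code": 0, "data": {"id": "...", "name": "my_dataset", ...}}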


@manager.route("/datasets", methods=["DELETE"])  # noqa: F821
@token_required
def delete(tenant_id):
    """
    Delete datasets.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset deletion parameters.
        required: true
        schema:
          type: object
          properties:
            ids:
              type: array
              items:
                type: string
              description: List of dataset IDs to delete. If omitted, all datasets owned by the tenant are deleted.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
    """
    errors = []
    success_count = 0
    req = request.json
    if not req:
        ids = None
    else:
        ids = req.get("ids")

    if not ids:
        id_list = []
        kbs = KnowledgebaseService.query(tenant_id=tenant_id)
        for kb in kbs:
            id_list.append(kb.id)
    else:
        id_list = ids

    unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "dataset")
    id_list = unique_id_list

    for id in id_list:
        kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id)
        if not kbs:
            errors.append(f"You don't own the dataset {id}")
            continue
        for doc in DocumentService.query(kb_id=id):
            if not DocumentService.remove_document(doc, tenant_id):
                errors.append(f"Remove document error for dataset {id}")
                continue
            f2d = File2DocumentService.get_by_document_id(doc.id)
            FileService.filter_delete(
                [
                    File.source_type == FileSource.KNOWLEDGEBASE,
                    File.id == f2d[0].file_id,
                ]
            )
            File2DocumentService.delete_by_document_id(doc.id)
        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
        if not KnowledgebaseService.delete_by_id(id):
            errors.append(f"Delete dataset error for {id}")
            continue
        success_count += 1

    if errors:
        if success_count > 0:
            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} datasets with {len(errors)} errors")
        else:
            return get_error_data_result(message="; ".join(errors))

    if duplicate_messages:
        if success_count > 0:
            return get_result(
                message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors",
                data={"success_count": success_count, "errors": duplicate_messages},
            )
        else:
            return get_error_data_result(message=";".join(duplicate_messages))

    return get_result(code=settings.RetCode.SUCCESS)
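
# A minimal usage sketch for the delete endpoint above (illustrative only:
# the base URL, port, and API key are deployment assumptions, not defined here):
#
#   import requests
#   resp = requests.delete(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       json={"ids": ["<dataset_id>"]},  # omit "ids" to delete every dataset you own
#   )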


@manager.route("/datasets/<dataset_id>", methods=["PUT"])  # noqa: F821
@token_required
def update(tenant_id, dataset_id):
    """
    Update a dataset.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: path
        name: dataset_id
        type: string
        required: true
        description: ID of the dataset to update.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
      - in: body
        name: body
        description: Dataset update parameters.
        required: true
        schema:
          type: object
          properties:
            name:
              type: string
              description: New name of the dataset.
            avatar:
              type: string
              description: Updated base64 encoding of the avatar.
            description:
              type: string
              description: Updated description of the dataset.
            embedding_model:
              type: string
              description: Updated embedding model name.
            permission:
              type: string
              enum: ['me', 'team']
              description: Updated dataset permission.
            chunk_method:
              type: string
              enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
                     "picture", "presentation", "qa", "table", "tag"]
              description: Updated chunking method.
            pagerank:
              type: integer
              description: Updated page rank.
            parser_config:
              type: object
              description: Updated parser configuration.
    responses:
      200:
        description: Successful operation.
        schema:
          type: object
    """
    # Field name transformations during model dump:
    # | Original        | Dump Output |
    # |-----------------|-------------|
    # | embedding_model | embd_id     |
    # | chunk_method    | parser_id   |
    extras = {"dataset_id": dataset_id}
    req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
    if err is not None:
        return get_error_argument_result(err)

    if not req:
        return get_error_argument_result(message="No properties were modified")

    try:
        kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
        if kb is None:
            return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    if req.get("parser_config"):
        req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])

    if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id:
        # Chunk method changed: fall back to the default parser_config for the
        # new method unless the request supplied one.
        if not req.get("parser_config"):
            req["parser_config"] = get_parser_config(chunk_method, None)
    elif "parser_config" in req and not req["parser_config"]:
        # An explicitly empty parser_config would wipe the stored one; drop it.
        del req["parser_config"]

    if "name" in req and req["name"].lower() != kb.name.lower():
        try:
            exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
            if exists:
                return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
        except OperationalError as e:
            logging.exception(e)
            return get_error_data_result(message="Database operation failed")

    if "embd_id" in req:
        if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
            return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
        ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
        if not ok:
            return err

    try:
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_error_data_result(message="Update dataset error. (Database error)")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    return get_result(code=settings.RetCode.SUCCESS)
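
# A minimal usage sketch for the update endpoint above (illustrative only:
# the base URL, port, and API key are deployment assumptions, not defined here):
#
#   import requests
#   resp = requests.put(
#       "http://localhost:9380/api/v1/datasets/<dataset_id>",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       json={"name": "renamed_dataset"},
#   )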


@manager.route("/datasets", methods=["GET"])  # noqa: F821
@token_required
def list_datasets(tenant_id):
    """
    List datasets.
    ---
    tags:
      - Datasets
    security:
      - ApiKeyAuth: []
    parameters:
      - in: query
        name: id
        type: string
        required: false
        description: Dataset ID to filter by.
      - in: query
        name: name
        type: string
        required: false
        description: Dataset name to filter by.
      - in: query
        name: page
        type: integer
        required: false
        default: 1
        description: Page number.
      - in: query
        name: page_size
        type: integer
        required: false
        default: 30
        description: Number of items per page.
      - in: query
        name: orderby
        type: string
        required: false
        default: "create_time"
        description: Field to order by.
      - in: query
        name: desc
        type: boolean
        required: false
        default: true
        description: Whether to sort in descending order.
      - in: header
        name: Authorization
        type: string
        required: true
        description: Bearer token for authentication.
    responses:
      200:
        description: Successful operation.
        schema:
          type: array
          items:
            type: object
    """
    id = request.args.get("id")
    name = request.args.get("name")
    if id:
        kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
        if not kbs:
            return get_error_data_result(f"You don't own the dataset {id}")
    if name:
        kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
        if not kbs:
            return get_error_data_result(f"You don't own the dataset {name}")

    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 30))
    orderby = request.args.get("orderby", "create_time")
    desc_arg = request.args.get("desc", "true").lower()
    if desc_arg not in ["true", "false"]:
        return get_error_data_result("desc should be true or false")
    desc = desc_arg == "true"

    tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
    kbs = KnowledgebaseService.get_list(
        [m["tenant_id"] for m in tenants],
        tenant_id,
        page_number,
        items_per_page,
        orderby,
        desc,
        id,
        name,
    )

    key_mapping = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }
    renamed_list = []
    for kb in kbs:
        renamed_data = {}
        for key, value in kb.items():
            new_key = key_mapping.get(key, key)
            renamed_data[new_key] = value
        renamed_list.append(renamed_data)
    return get_result(data=renamed_list)
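
# A minimal usage sketch for the list endpoint above (illustrative only:
# the base URL, port, and API key are deployment assumptions, not defined here):
#
#   import requests
#   resp = requests.get(
#       "http://localhost:9380/api/v1/datasets",
#       headers={"Authorization": "Bearer <API_KEY>"},
#       params={"page": 1, "page_size": 30, "orderby": "create_time", "desc": "true"},
#   )
#   print(resp.json())  # {"code": 0, "data": [{"id": "...", "name": "..."}, ...]}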