Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

dataset.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. from flask import request
  18. from peewee import OperationalError
  19. from api.db import FileSource, StatusEnum
  20. from api.db.db_models import File
  21. from api.db.services.document_service import DocumentService
  22. from api.db.services.file2document_service import File2DocumentService
  23. from api.db.services.file_service import FileService
  24. from api.db.services.knowledgebase_service import KnowledgebaseService
  25. from api.db.services.user_service import TenantService
  26. from api.utils import get_uuid
  27. from api.utils.api_utils import (
  28. deep_merge,
  29. get_error_argument_result,
  30. get_error_data_result,
  31. get_error_operating_result,
  32. get_error_permission_result,
  33. get_parser_config,
  34. get_result,
  35. remap_dictionary_keys,
  36. token_required,
  37. verify_embedding_availability,
  38. )
  39. from api.utils.validation_utils import (
  40. CreateDatasetReq,
  41. DeleteDatasetReq,
  42. ListDatasetReq,
  43. UpdateDatasetReq,
  44. validate_and_parse_json_request,
  45. validate_and_parse_request_args,
  46. )
@manager.route("/datasets", methods=["POST"])  # noqa: F821
@token_required
def create(tenant_id):
    """
    Create a new dataset.
    ---
    tags:
        - Datasets
    security:
        - ApiKeyAuth: []
    parameters:
        - in: header
          name: Authorization
          type: string
          required: true
          description: Bearer token for authentication.
        - in: body
          name: body
          description: Dataset creation parameters.
          required: true
          schema:
              type: object
              required:
                  - name
              properties:
                  name:
                      type: string
                      description: Name of the dataset.
                  avatar:
                      type: string
                      description: Base64 encoding of the avatar.
                  description:
                      type: string
                      description: Description of the dataset.
                  embedding_model:
                      type: string
                      description: Embedding model Name.
                  permission:
                      type: string
                      enum: ['me', 'team']
                      description: Dataset permission.
                  chunk_method:
                      type: string
                      enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
                             "picture", "presentation", "qa", "table", "tag"
                             ]
                      description: Chunking method.
                  pagerank:
                      type: integer
                      description: Set page rank.
                  parser_config:
                      type: object
                      description: Parser configuration.
    responses:
        200:
            description: Successful operation.
            schema:
                type: object
                properties:
                    data:
                        type: object
    """
    # Field name transformations during model dump:
    # | Original       | Dump Output |
    # |----------------|-------------|
    # | embedding_model| embd_id     |
    # | chunk_method   | parser_id   |
    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return get_error_argument_result(err)

    # Dataset names must be unique per tenant among VALID (non-deleted) records.
    try:
        if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
            return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    # Fill in server-generated fields before persisting.
    req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
    req["id"] = get_uuid()
    req["tenant_id"] = tenant_id
    req["created_by"] = tenant_id

    try:
        ok, t = TenantService.get_by_id(tenant_id)
        if not ok:
            return get_error_permission_result(message="Tenant not found")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    # Fall back to the tenant's default embedding model when none was supplied;
    # otherwise verify the requested model is actually available to this tenant.
    if not req.get("embd_id"):
        req["embd_id"] = t.embd_id
    else:
        ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
        if not ok:
            return err

    try:
        if not KnowledgebaseService.save(**req):
            return get_error_data_result(message="Create dataset error.(Database error)")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    # Re-read the stored record so the response reflects what was actually
    # persisted (including any model-level defaults).
    try:
        ok, k = KnowledgebaseService.get_by_id(req["id"])
        if not ok:
            return get_error_data_result(message="Dataset created failed")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    response_data = remap_dictionary_keys(k.to_dict())
    return get_result(data=response_data)
  155. @manager.route("/datasets", methods=["DELETE"]) # noqa: F821
  156. @token_required
  157. def delete(tenant_id):
  158. """
  159. Delete datasets.
  160. ---
  161. tags:
  162. - Datasets
  163. security:
  164. - ApiKeyAuth: []
  165. parameters:
  166. - in: header
  167. name: Authorization
  168. type: string
  169. required: true
  170. description: Bearer token for authentication.
  171. - in: body
  172. name: body
  173. description: Dataset deletion parameters.
  174. required: true
  175. schema:
  176. type: object
  177. required:
  178. - ids
  179. properties:
  180. ids:
  181. type: array or null
  182. items:
  183. type: string
  184. description: |
  185. Specifies the datasets to delete:
  186. - If `null`, all datasets will be deleted.
  187. - If an array of IDs, only the specified datasets will be deleted.
  188. - If an empty array, no datasets will be deleted.
  189. responses:
  190. 200:
  191. description: Successful operation.
  192. schema:
  193. type: object
  194. """
  195. req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
  196. if err is not None:
  197. return get_error_argument_result(err)
  198. kb_id_instance_pairs = []
  199. if req["ids"] is None:
  200. try:
  201. kbs = KnowledgebaseService.query(tenant_id=tenant_id)
  202. for kb in kbs:
  203. kb_id_instance_pairs.append((kb.id, kb))
  204. except OperationalError as e:
  205. logging.exception(e)
  206. return get_error_data_result(message="Database operation failed")
  207. else:
  208. error_kb_ids = []
  209. for kb_id in req["ids"]:
  210. try:
  211. kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id)
  212. if kb is None:
  213. error_kb_ids.append(kb_id)
  214. continue
  215. kb_id_instance_pairs.append((kb_id, kb))
  216. except OperationalError as e:
  217. logging.exception(e)
  218. return get_error_data_result(message="Database operation failed")
  219. if len(error_kb_ids) > 0:
  220. return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
  221. errors = []
  222. success_count = 0
  223. for kb_id, kb in kb_id_instance_pairs:
  224. try:
  225. for doc in DocumentService.query(kb_id=kb_id):
  226. if not DocumentService.remove_document(doc, tenant_id):
  227. errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'")
  228. continue
  229. f2d = File2DocumentService.get_by_document_id(doc.id)
  230. FileService.filter_delete(
  231. [
  232. File.source_type == FileSource.KNOWLEDGEBASE,
  233. File.id == f2d[0].file_id,
  234. ]
  235. )
  236. File2DocumentService.delete_by_document_id(doc.id)
  237. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
  238. if not KnowledgebaseService.delete_by_id(kb_id):
  239. errors.append(f"Delete dataset error for {kb_id}")
  240. continue
  241. success_count += 1
  242. except OperationalError as e:
  243. logging.exception(e)
  244. return get_error_data_result(message="Database operation failed")
  245. if not errors:
  246. return get_result()
  247. error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..."
  248. if success_count == 0:
  249. return get_error_data_result(message=error_message)
  250. return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message)
@manager.route("/datasets/<dataset_id>", methods=["PUT"])  # noqa: F821
@token_required
def update(tenant_id, dataset_id):
    """
    Update a dataset.
    ---
    tags:
        - Datasets
    security:
        - ApiKeyAuth: []
    parameters:
        - in: path
          name: dataset_id
          type: string
          required: true
          description: ID of the dataset to update.
        - in: header
          name: Authorization
          type: string
          required: true
          description: Bearer token for authentication.
        - in: body
          name: body
          description: Dataset update parameters.
          required: true
          schema:
              type: object
              properties:
                  name:
                      type: string
                      description: New name of the dataset.
                  avatar:
                      type: string
                      description: Updated base64 encoding of the avatar.
                  description:
                      type: string
                      description: Updated description of the dataset.
                  embedding_model:
                      type: string
                      description: Updated embedding model Name.
                  permission:
                      type: string
                      enum: ['me', 'team']
                      description: Updated dataset permission.
                  chunk_method:
                      type: string
                      enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
                             "picture", "presentation", "qa", "table", "tag"
                             ]
                      description: Updated chunking method.
                  pagerank:
                      type: integer
                      description: Updated page rank.
                  parser_config:
                      type: object
                      description: Updated parser configuration.
    responses:
        200:
            description: Successful operation.
            schema:
                type: object
    """
    # Field name transformations during model dump:
    # | Original       | Dump Output |
    # |----------------|-------------|
    # | embedding_model| embd_id     |
    # | chunk_method   | parser_id   |
    extras = {"dataset_id": dataset_id}
    # exclude_unset=True makes this a partial update: only fields the caller
    # actually sent appear in `req`.
    req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
    if err is not None:
        return get_error_argument_result(err)
    if not req:
        return get_error_argument_result(message="No properties were modified")

    try:
        kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
        if kb is None:
            return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    # Merge the incoming parser_config into the stored one so a partial config
    # update does not wipe keys the caller did not mention.
    if req.get("parser_config"):
        req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
    if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id:
        # Chunk method changed without an explicit parser_config: reset the
        # config to the defaults for the new method.
        if not req.get("parser_config"):
            req["parser_config"] = get_parser_config(chunk_method, None)
    elif "parser_config" in req and not req["parser_config"]:
        # An explicitly empty parser_config carries no information; drop it so
        # it cannot overwrite the stored configuration with an empty value.
        del req["parser_config"]

    # Only enforce name uniqueness when the name genuinely changes; the
    # comparison is case-insensitive so "Foo" -> "foo" is not a rename.
    if "name" in req and req["name"].lower() != kb.name.lower():
        try:
            exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
            if exists:
                return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
        except OperationalError as e:
            logging.exception(e)
            return get_error_data_result(message="Database operation failed")

    if "embd_id" in req:
        # The embedding model is locked once chunks exist: switching models
        # would leave previously computed vectors incompatible with queries.
        if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
            return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
        ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
        if not ok:
            return err

    try:
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_error_data_result(message="Update dataset error.(Database error)")
    except OperationalError as e:
        logging.exception(e)
        return get_error_data_result(message="Database operation failed")

    return get_result()
  359. @manager.route("/datasets", methods=["GET"]) # noqa: F821
  360. @token_required
  361. def list_datasets(tenant_id):
  362. """
  363. List datasets.
  364. ---
  365. tags:
  366. - Datasets
  367. security:
  368. - ApiKeyAuth: []
  369. parameters:
  370. - in: query
  371. name: id
  372. type: string
  373. required: false
  374. description: Dataset ID to filter.
  375. - in: query
  376. name: name
  377. type: string
  378. required: false
  379. description: Dataset name to filter.
  380. - in: query
  381. name: page
  382. type: integer
  383. required: false
  384. default: 1
  385. description: Page number.
  386. - in: query
  387. name: page_size
  388. type: integer
  389. required: false
  390. default: 30
  391. description: Number of items per page.
  392. - in: query
  393. name: orderby
  394. type: string
  395. required: false
  396. default: "create_time"
  397. description: Field to order by.
  398. - in: query
  399. name: desc
  400. type: boolean
  401. required: false
  402. default: true
  403. description: Order in descending.
  404. - in: header
  405. name: Authorization
  406. type: string
  407. required: true
  408. description: Bearer token for authentication.
  409. responses:
  410. 200:
  411. description: Successful operation.
  412. schema:
  413. type: array
  414. items:
  415. type: object
  416. """
  417. args, err = validate_and_parse_request_args(request, ListDatasetReq)
  418. if err is not None:
  419. return get_error_argument_result(err)
  420. kb_id = request.args.get("id")
  421. name = args.get("name")
  422. if kb_id:
  423. try:
  424. kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
  425. except OperationalError as e:
  426. logging.exception(e)
  427. return get_error_data_result(message="Database operation failed")
  428. if not kbs:
  429. return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
  430. if name:
  431. try:
  432. kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
  433. except OperationalError as e:
  434. logging.exception(e)
  435. return get_error_data_result(message="Database operation failed")
  436. if not kbs:
  437. return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
  438. try:
  439. tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
  440. kbs = KnowledgebaseService.get_list(
  441. [m["tenant_id"] for m in tenants],
  442. tenant_id,
  443. args["page"],
  444. args["page_size"],
  445. args["orderby"],
  446. args["desc"],
  447. kb_id,
  448. name,
  449. )
  450. except OperationalError as e:
  451. logging.exception(e)
  452. return get_error_data_result(message="Database operation failed")
  453. response_data_list = []
  454. for kb in kbs:
  455. response_data_list.append(remap_dictionary_keys(kb))
  456. return get_result(data=response_data_list)