Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

dataset.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import json
  19. from flask import request
  20. from peewee import OperationalError
  21. from api import settings
  22. from api.db import FileSource, StatusEnum
  23. from api.db.db_models import File
  24. from api.db.services.document_service import DocumentService
  25. from api.db.services.file2document_service import File2DocumentService
  26. from api.db.services.file_service import FileService
  27. from api.db.services.knowledgebase_service import KnowledgebaseService
  28. from api.db.services.user_service import TenantService
  29. from api.utils import get_uuid
  30. from api.utils.api_utils import (
  31. deep_merge,
  32. get_error_argument_result,
  33. get_error_data_result,
  34. get_error_operating_result,
  35. get_error_permission_result,
  36. get_parser_config,
  37. get_result,
  38. remap_dictionary_keys,
  39. token_required,
  40. verify_embedding_availability,
  41. )
  42. from api.utils.validation_utils import (
  43. CreateDatasetReq,
  44. DeleteDatasetReq,
  45. ListDatasetReq,
  46. UpdateDatasetReq,
  47. validate_and_parse_json_request,
  48. validate_and_parse_request_args,
  49. )
  50. from rag.nlp import search
  51. from rag.settings import PAGERANK_FLD
  52. @manager.route("/datasets", methods=["POST"]) # noqa: F821
  53. @token_required
  54. def create(tenant_id):
  55. """
  56. Create a new dataset.
  57. ---
  58. tags:
  59. - Datasets
  60. security:
  61. - ApiKeyAuth: []
  62. parameters:
  63. - in: header
  64. name: Authorization
  65. type: string
  66. required: true
  67. description: Bearer token for authentication.
  68. - in: body
  69. name: body
  70. description: Dataset creation parameters.
  71. required: true
  72. schema:
  73. type: object
  74. required:
  75. - name
  76. properties:
  77. name:
  78. type: string
  79. description: Name of the dataset.
  80. avatar:
  81. type: string
  82. description: Base64 encoding of the avatar.
  83. description:
  84. type: string
  85. description: Description of the dataset.
  86. embedding_model:
  87. type: string
  88. description: Embedding model Name.
  89. permission:
  90. type: string
  91. enum: ['me', 'team']
  92. description: Dataset permission.
  93. chunk_method:
  94. type: string
  95. enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
  96. "picture", "presentation", "qa", "table", "tag"
  97. ]
  98. description: Chunking method.
  99. parser_config:
  100. type: object
  101. description: Parser configuration.
  102. responses:
  103. 200:
  104. description: Successful operation.
  105. schema:
  106. type: object
  107. properties:
  108. data:
  109. type: object
  110. """
  111. # Field name transformations during model dump:
  112. # | Original | Dump Output |
  113. # |----------------|-------------|
  114. # | embedding_model| embd_id |
  115. # | chunk_method | parser_id |
  116. req, err = validate_and_parse_json_request(request, CreateDatasetReq)
  117. if err is not None:
  118. return get_error_argument_result(err)
  119. try:
  120. if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
  121. return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
  122. req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
  123. req["id"] = get_uuid()
  124. req["tenant_id"] = tenant_id
  125. req["created_by"] = tenant_id
  126. ok, t = TenantService.get_by_id(tenant_id)
  127. if not ok:
  128. return get_error_permission_result(message="Tenant not found")
  129. if not req.get("embd_id"):
  130. req["embd_id"] = t.embd_id
  131. else:
  132. ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
  133. if not ok:
  134. return err
  135. if not KnowledgebaseService.save(**req):
  136. return get_error_data_result(message="Create dataset error.(Database error)")
  137. ok, k = KnowledgebaseService.get_by_id(req["id"])
  138. if not ok:
  139. return get_error_data_result(message="Dataset created failed")
  140. response_data = remap_dictionary_keys(k.to_dict())
  141. return get_result(data=response_data)
  142. except OperationalError as e:
  143. logging.exception(e)
  144. return get_error_data_result(message="Database operation failed")
  145. @manager.route("/datasets", methods=["DELETE"]) # noqa: F821
  146. @token_required
  147. def delete(tenant_id):
  148. """
  149. Delete datasets.
  150. ---
  151. tags:
  152. - Datasets
  153. security:
  154. - ApiKeyAuth: []
  155. parameters:
  156. - in: header
  157. name: Authorization
  158. type: string
  159. required: true
  160. description: Bearer token for authentication.
  161. - in: body
  162. name: body
  163. description: Dataset deletion parameters.
  164. required: true
  165. schema:
  166. type: object
  167. required:
  168. - ids
  169. properties:
  170. ids:
  171. type: array or null
  172. items:
  173. type: string
  174. description: |
  175. Specifies the datasets to delete:
  176. - If `null`, all datasets will be deleted.
  177. - If an array of IDs, only the specified datasets will be deleted.
  178. - If an empty array, no datasets will be deleted.
  179. responses:
  180. 200:
  181. description: Successful operation.
  182. schema:
  183. type: object
  184. """
  185. req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
  186. if err is not None:
  187. return get_error_argument_result(err)
  188. try:
  189. kb_id_instance_pairs = []
  190. if req["ids"] is None:
  191. kbs = KnowledgebaseService.query(tenant_id=tenant_id)
  192. for kb in kbs:
  193. kb_id_instance_pairs.append((kb.id, kb))
  194. else:
  195. error_kb_ids = []
  196. for kb_id in req["ids"]:
  197. kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id)
  198. if kb is None:
  199. error_kb_ids.append(kb_id)
  200. continue
  201. kb_id_instance_pairs.append((kb_id, kb))
  202. if len(error_kb_ids) > 0:
  203. return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
  204. errors = []
  205. success_count = 0
  206. for kb_id, kb in kb_id_instance_pairs:
  207. for doc in DocumentService.query(kb_id=kb_id):
  208. if not DocumentService.remove_document(doc, tenant_id):
  209. errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'")
  210. continue
  211. f2d = File2DocumentService.get_by_document_id(doc.id)
  212. FileService.filter_delete(
  213. [
  214. File.source_type == FileSource.KNOWLEDGEBASE,
  215. File.id == f2d[0].file_id,
  216. ]
  217. )
  218. File2DocumentService.delete_by_document_id(doc.id)
  219. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
  220. if not KnowledgebaseService.delete_by_id(kb_id):
  221. errors.append(f"Delete dataset error for {kb_id}")
  222. continue
  223. success_count += 1
  224. if not errors:
  225. return get_result()
  226. error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..."
  227. if success_count == 0:
  228. return get_error_data_result(message=error_message)
  229. return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message)
  230. except OperationalError as e:
  231. logging.exception(e)
  232. return get_error_data_result(message="Database operation failed")
  233. @manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
  234. @token_required
  235. def update(tenant_id, dataset_id):
  236. """
  237. Update a dataset.
  238. ---
  239. tags:
  240. - Datasets
  241. security:
  242. - ApiKeyAuth: []
  243. parameters:
  244. - in: path
  245. name: dataset_id
  246. type: string
  247. required: true
  248. description: ID of the dataset to update.
  249. - in: header
  250. name: Authorization
  251. type: string
  252. required: true
  253. description: Bearer token for authentication.
  254. - in: body
  255. name: body
  256. description: Dataset update parameters.
  257. required: true
  258. schema:
  259. type: object
  260. properties:
  261. name:
  262. type: string
  263. description: New name of the dataset.
  264. avatar:
  265. type: string
  266. description: Updated base64 encoding of the avatar.
  267. description:
  268. type: string
  269. description: Updated description of the dataset.
  270. embedding_model:
  271. type: string
  272. description: Updated embedding model Name.
  273. permission:
  274. type: string
  275. enum: ['me', 'team']
  276. description: Updated dataset permission.
  277. chunk_method:
  278. type: string
  279. enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
  280. "picture", "presentation", "qa", "table", "tag"
  281. ]
  282. description: Updated chunking method.
  283. pagerank:
  284. type: integer
  285. description: Updated page rank.
  286. parser_config:
  287. type: object
  288. description: Updated parser configuration.
  289. responses:
  290. 200:
  291. description: Successful operation.
  292. schema:
  293. type: object
  294. """
  295. # Field name transformations during model dump:
  296. # | Original | Dump Output |
  297. # |----------------|-------------|
  298. # | embedding_model| embd_id |
  299. # | chunk_method | parser_id |
  300. extras = {"dataset_id": dataset_id}
  301. req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
  302. if err is not None:
  303. return get_error_argument_result(err)
  304. if not req:
  305. return get_error_argument_result(message="No properties were modified")
  306. try:
  307. kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
  308. if kb is None:
  309. return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
  310. if req.get("parser_config"):
  311. req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
  312. if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id:
  313. if not req.get("parser_config"):
  314. req["parser_config"] = get_parser_config(chunk_method, None)
  315. elif "parser_config" in req and not req["parser_config"]:
  316. del req["parser_config"]
  317. if "name" in req and req["name"].lower() != kb.name.lower():
  318. exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
  319. if exists:
  320. return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
  321. if "embd_id" in req:
  322. if not req["embd_id"]:
  323. req["embd_id"] = kb.embd_id
  324. if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
  325. return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
  326. ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
  327. if not ok:
  328. return err
  329. if "pagerank" in req and req["pagerank"] != kb.pagerank:
  330. if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
  331. return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
  332. if req["pagerank"] > 0:
  333. settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
  334. else:
  335. # Elasticsearch requires PAGERANK_FLD be non-zero!
  336. settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
  337. if not KnowledgebaseService.update_by_id(kb.id, req):
  338. return get_error_data_result(message="Update dataset error.(Database error)")
  339. ok, k = KnowledgebaseService.get_by_id(kb.id)
  340. if not ok:
  341. return get_error_data_result(message="Dataset created failed")
  342. response_data = remap_dictionary_keys(k.to_dict())
  343. return get_result(data=response_data)
  344. except OperationalError as e:
  345. logging.exception(e)
  346. return get_error_data_result(message="Database operation failed")
  347. @manager.route("/datasets", methods=["GET"]) # noqa: F821
  348. @token_required
  349. def list_datasets(tenant_id):
  350. """
  351. List datasets.
  352. ---
  353. tags:
  354. - Datasets
  355. security:
  356. - ApiKeyAuth: []
  357. parameters:
  358. - in: query
  359. name: id
  360. type: string
  361. required: false
  362. description: Dataset ID to filter.
  363. - in: query
  364. name: name
  365. type: string
  366. required: false
  367. description: Dataset name to filter.
  368. - in: query
  369. name: page
  370. type: integer
  371. required: false
  372. default: 1
  373. description: Page number.
  374. - in: query
  375. name: page_size
  376. type: integer
  377. required: false
  378. default: 30
  379. description: Number of items per page.
  380. - in: query
  381. name: orderby
  382. type: string
  383. required: false
  384. default: "create_time"
  385. description: Field to order by.
  386. - in: query
  387. name: desc
  388. type: boolean
  389. required: false
  390. default: true
  391. description: Order in descending.
  392. - in: header
  393. name: Authorization
  394. type: string
  395. required: true
  396. description: Bearer token for authentication.
  397. responses:
  398. 200:
  399. description: Successful operation.
  400. schema:
  401. type: array
  402. items:
  403. type: object
  404. """
  405. args, err = validate_and_parse_request_args(request, ListDatasetReq)
  406. if err is not None:
  407. return get_error_argument_result(err)
  408. try:
  409. kb_id = request.args.get("id")
  410. name = args.get("name")
  411. if kb_id:
  412. kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
  413. if not kbs:
  414. return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
  415. if name:
  416. kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
  417. if not kbs:
  418. return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
  419. tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
  420. kbs = KnowledgebaseService.get_list(
  421. [m["tenant_id"] for m in tenants],
  422. tenant_id,
  423. args["page"],
  424. args["page_size"],
  425. args["orderby"],
  426. args["desc"],
  427. kb_id,
  428. name,
  429. )
  430. response_data_list = []
  431. for kb in kbs:
  432. response_data_list.append(remap_dictionary_keys(kb))
  433. return get_result(data=response_data_list)
  434. except OperationalError as e:
  435. logging.exception(e)
  436. return get_error_data_result(message="Database operation failed")
  437. @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
  438. @token_required
  439. def knowledge_graph(tenant_id,dataset_id):
  440. if not KnowledgebaseService.accessible(dataset_id, tenant_id):
  441. return get_result(
  442. data=False,
  443. message='No authorization.',
  444. code=settings.RetCode.AUTHENTICATION_ERROR
  445. )
  446. _, kb = KnowledgebaseService.get_by_id(dataset_id)
  447. req = {
  448. "kb_id": [dataset_id],
  449. "knowledge_graph_kwd": ["graph"]
  450. }
  451. obj = {"graph": {}, "mind_map": {}}
  452. if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
  453. return get_result(data=obj)
  454. sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id])
  455. if not len(sres.ids):
  456. return get_result(data=obj)
  457. for id in sres.ids[:1]:
  458. ty = sres.field[id]["knowledge_graph_kwd"]
  459. try:
  460. content_json = json.loads(sres.field[id]["content_with_weight"])
  461. except Exception:
  462. continue
  463. obj[ty] = content_json
  464. if "nodes" in obj["graph"]:
  465. obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
  466. if "edges" in obj["graph"]:
  467. node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
  468. filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
  469. obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
  470. return get_result(data=obj)
  471. @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
  472. @token_required
  473. def delete_knowledge_graph(tenant_id,dataset_id):
  474. if not KnowledgebaseService.accessible(dataset_id, tenant_id):
  475. return get_result(
  476. data=False,
  477. message='No authorization.',
  478. code=settings.RetCode.AUTHENTICATION_ERROR
  479. )
  480. _, kb = KnowledgebaseService.get_by_id(dataset_id)
  481. settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
  482. return get_result(data=True)