Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

doc.py 28KB


  1. import pathlib
  2. import re
  3. import datetime
  4. import json
  5. import traceback
  6. from flask import request
  7. from flask_login import login_required, current_user
  8. from elasticsearch_dsl import Q
  9. from rag.app.qa import rmPrefix, beAdoc
  10. from rag.nlp import search, rag_tokenizer, keyword_extraction
  11. from rag.utils.es_conn import ELASTICSEARCH
  12. from rag.utils import rmSpace
  13. from api.db import LLMType, ParserType
  14. from api.db.services.knowledgebase_service import KnowledgebaseService
  15. from api.db.services.llm_service import TenantLLMService
  16. from api.db.services.user_service import UserTenantService
  17. from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
  18. from api.db.services.document_service import DocumentService
  19. from api.settings import RetCode, retrievaler, kg_retrievaler
  20. from api.utils.api_utils import get_json_result
  21. import hashlib
  22. import re
  23. from api.utils.api_utils import get_json_result, token_required, get_data_error_result
  24. from api.db.db_models import Task, File
  25. from api.db.services.task_service import TaskService, queue_tasks
  26. from api.db.services.user_service import TenantService, UserTenantService
  27. from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
  28. from api.utils.api_utils import get_json_result
  29. from functools import partial
  30. from io import BytesIO
  31. from elasticsearch_dsl import Q
  32. from flask import request, send_file
  33. from flask_login import login_required
  34. from api.db import FileSource, TaskStatus, FileType
  35. from api.db.db_models import File
  36. from api.db.services.document_service import DocumentService
  37. from api.db.services.file2document_service import File2DocumentService
  38. from api.db.services.file_service import FileService
  39. from api.db.services.knowledgebase_service import KnowledgebaseService
  40. from api.settings import RetCode, retrievaler
  41. from api.utils.api_utils import construct_json_result, construct_error_response
  42. from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio, email
  43. from rag.nlp import search
  44. from rag.utils import rmSpace
  45. from rag.utils.es_conn import ELASTICSEARCH
  46. from rag.utils.storage_factory import STORAGE_IMPL
  47. MAXIMUM_OF_UPLOADING_FILES = 256
  48. MAXIMUM_OF_UPLOADING_FILES = 256
  49. @manager.route('/dataset/<dataset_id>/documents/upload', methods=['POST'])
  50. @token_required
  51. def upload(dataset_id, tenant_id):
  52. if 'file' not in request.files:
  53. return get_json_result(
  54. data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
  55. file_objs = request.files.getlist('file')
  56. for file_obj in file_objs:
  57. if file_obj.filename == '':
  58. return get_json_result(
  59. data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
  60. e, kb = KnowledgebaseService.get_by_id(dataset_id)
  61. if not e:
  62. raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
  63. err, _ = FileService.upload_document(kb, file_objs, tenant_id)
  64. if err:
  65. return get_json_result(
  66. data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
  67. return get_json_result(data=True)
  68. @manager.route('/infos', methods=['GET'])
  69. @token_required
  70. def docinfos(tenant_id):
  71. req = request.args
  72. if "id" not in req and "name" not in req:
  73. return get_data_error_result(
  74. retmsg="Id or name should be provided")
  75. doc_id=None
  76. if "id" in req:
  77. doc_id = req["id"]
  78. if "name" in req:
  79. doc_name = req["name"]
  80. doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
  81. e, doc = DocumentService.get_by_id(doc_id)
  82. #rename key's name
  83. key_mapping = {
  84. "chunk_num": "chunk_count",
  85. "kb_id": "knowledgebase_id",
  86. "token_num": "token_count",
  87. }
  88. renamed_doc = {}
  89. for key, value in doc.to_dict().items():
  90. new_key = key_mapping.get(key, key)
  91. renamed_doc[new_key] = value
  92. return get_json_result(data=renamed_doc)
  93. @manager.route('/save', methods=['POST'])
  94. @token_required
  95. def save_doc(tenant_id):
  96. req = request.json
  97. #get doc by id or name
  98. doc_id = None
  99. if "id" in req:
  100. doc_id = req["id"]
  101. elif "name" in req:
  102. doc_name = req["name"]
  103. doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
  104. if not doc_id:
  105. return get_json_result(retcode=400, retmsg="Document ID or name is required")
  106. e, doc = DocumentService.get_by_id(doc_id)
  107. if not e:
  108. return get_data_error_result(retmsg="Document not found!")
  109. #other value can't be changed
  110. if "chunk_num" in req:
  111. if req["chunk_num"] != doc.chunk_num:
  112. return get_data_error_result(
  113. retmsg="Can't change chunk_count.")
  114. if "progress" in req:
  115. if req['progress'] != doc.progress:
  116. return get_data_error_result(
  117. retmsg="Can't change progress.")
  118. #change name or parse_method
  119. if "name" in req and req["name"] != doc.name:
  120. try:
  121. if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
  122. doc.name.lower()).suffix:
  123. return get_json_result(
  124. data=False,
  125. retmsg="The extension of file can't be changed",
  126. retcode=RetCode.ARGUMENT_ERROR)
  127. for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
  128. if d.name == req["name"]:
  129. return get_data_error_result(
  130. retmsg="Duplicated document name in the same knowledgebase.")
  131. if not DocumentService.update_by_id(
  132. doc_id, {"name": req["name"]}):
  133. return get_data_error_result(
  134. retmsg="Database error (Document rename)!")
  135. informs = File2DocumentService.get_by_document_id(doc_id)
  136. if informs:
  137. e, file = FileService.get_by_id(informs[0].file_id)
  138. FileService.update_by_id(file.id, {"name": req["name"]})
  139. except Exception as e:
  140. return server_error_response(e)
  141. if "parser_id" in req:
  142. try:
  143. if doc.parser_id.lower() == req["parser_id"].lower():
  144. if "parser_config" in req:
  145. if req["parser_config"] == doc.parser_config:
  146. return get_json_result(data=True)
  147. else:
  148. return get_json_result(data=True)
  149. if doc.type == FileType.VISUAL or re.search(
  150. r"\.(ppt|pptx|pages)$", doc.name):
  151. return get_data_error_result(retmsg="Not supported yet!")
  152. e = DocumentService.update_by_id(doc.id,
  153. {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
  154. "run": TaskStatus.UNSTART.value})
  155. if not e:
  156. return get_data_error_result(retmsg="Document not found!")
  157. if "parser_config" in req:
  158. DocumentService.update_parser_config(doc.id, req["parser_config"])
  159. if doc.token_num > 0:
  160. e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
  161. doc.process_duation * -1)
  162. if not e:
  163. return get_data_error_result(retmsg="Document not found!")
  164. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  165. if not tenant_id:
  166. return get_data_error_result(retmsg="Tenant not found!")
  167. ELASTICSEARCH.deleteByQuery(
  168. Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
  169. except Exception as e:
  170. return server_error_response(e)
  171. return get_json_result(data=True)
  172. @manager.route('/change_parser', methods=['POST'])
  173. @token_required
  174. def change_parser(tenant_id):
  175. req = request.json
  176. try:
  177. e, doc = DocumentService.get_by_id(req["doc_id"])
  178. if not e:
  179. return get_data_error_result(retmsg="Document not found!")
  180. if doc.parser_id.lower() == req["parser_id"].lower():
  181. if "parser_config" in req:
  182. if req["parser_config"] == doc.parser_config:
  183. return get_json_result(data=True)
  184. else:
  185. return get_json_result(data=True)
  186. if doc.type == FileType.VISUAL or re.search(
  187. r"\.(ppt|pptx|pages)$", doc.name):
  188. return get_data_error_result(retmsg="Not supported yet!")
  189. e = DocumentService.update_by_id(doc.id,
  190. {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
  191. "run": TaskStatus.UNSTART.value})
  192. if not e:
  193. return get_data_error_result(retmsg="Document not found!")
  194. if "parser_config" in req:
  195. DocumentService.update_parser_config(doc.id, req["parser_config"])
  196. if doc.token_num > 0:
  197. e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
  198. doc.process_duation * -1)
  199. if not e:
  200. return get_data_error_result(retmsg="Document not found!")
  201. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  202. if not tenant_id:
  203. return get_data_error_result(retmsg="Tenant not found!")
  204. ELASTICSEARCH.deleteByQuery(
  205. Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
  206. return get_json_result(data=True)
  207. except Exception as e:
  208. return server_error_response(e)
  209. @manager.route('/rename', methods=['POST'])
  210. @login_required
  211. @validate_request("doc_id", "name")
  212. def rename():
  213. req = request.json
  214. try:
  215. e, doc = DocumentService.get_by_id(req["doc_id"])
  216. if not e:
  217. return get_data_error_result(retmsg="Document not found!")
  218. if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
  219. doc.name.lower()).suffix:
  220. return get_json_result(
  221. data=False,
  222. retmsg="The extension of file can't be changed",
  223. retcode=RetCode.ARGUMENT_ERROR)
  224. for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
  225. if d.name == req["name"]:
  226. return get_data_error_result(
  227. retmsg="Duplicated document name in the same knowledgebase.")
  228. if not DocumentService.update_by_id(
  229. req["doc_id"], {"name": req["name"]}):
  230. return get_data_error_result(
  231. retmsg="Database error (Document rename)!")
  232. informs = File2DocumentService.get_by_document_id(req["doc_id"])
  233. if informs:
  234. e, file = FileService.get_by_id(informs[0].file_id)
  235. FileService.update_by_id(file.id, {"name": req["name"]})
  236. return get_json_result(data=True)
  237. except Exception as e:
  238. return server_error_response(e)
  239. @manager.route("/<document_id>", methods=["GET"])
  240. @token_required
  241. def download_document(dataset_id, document_id,tenant_id):
  242. try:
  243. # Check whether there is this document
  244. exist, document = DocumentService.get_by_id(document_id)
  245. if not exist:
  246. return construct_json_result(message=f"This document '{document_id}' cannot be found!",
  247. code=RetCode.ARGUMENT_ERROR)
  248. # The process of downloading
  249. doc_id, doc_location = File2DocumentService.get_minio_address(doc_id=document_id) # minio address
  250. file_stream = STORAGE_IMPL.get(doc_id, doc_location)
  251. if not file_stream:
  252. return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
  253. file = BytesIO(file_stream)
  254. # Use send_file with a proper filename and MIME type
  255. return send_file(
  256. file,
  257. as_attachment=True,
  258. download_name=document.name,
  259. mimetype='application/octet-stream' # Set a default MIME type
  260. )
  261. # Error
  262. except Exception as e:
  263. return construct_error_response(e)
  264. @manager.route('/dataset/<dataset_id>/documents', methods=['GET'])
  265. @token_required
  266. def list_docs(dataset_id, tenant_id):
  267. kb_id = request.args.get("kb_id")
  268. if not kb_id:
  269. return get_json_result(
  270. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  271. tenants = UserTenantService.query(user_id=tenant_id)
  272. for tenant in tenants:
  273. if KnowledgebaseService.query(
  274. tenant_id=tenant.tenant_id, id=kb_id):
  275. break
  276. else:
  277. return get_json_result(
  278. data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
  279. retcode=RetCode.OPERATING_ERROR)
  280. keywords = request.args.get("keywords", "")
  281. page_number = int(request.args.get("page", 1))
  282. items_per_page = int(request.args.get("page_size", 15))
  283. orderby = request.args.get("orderby", "create_time")
  284. desc = request.args.get("desc", True)
  285. try:
  286. docs, tol = DocumentService.get_by_kb_id(
  287. kb_id, page_number, items_per_page, orderby, desc, keywords)
  288. # rename key's name
  289. renamed_doc_list = []
  290. for doc in docs:
  291. key_mapping = {
  292. "chunk_num": "chunk_count",
  293. "kb_id": "knowledgebase_id",
  294. "token_num": "token_count",
  295. }
  296. renamed_doc = {}
  297. for key, value in doc.items():
  298. new_key = key_mapping.get(key, key)
  299. renamed_doc[new_key] = value
  300. renamed_doc_list.append(renamed_doc)
  301. return get_json_result(data={"total": tol, "docs": renamed_doc_list})
  302. except Exception as e:
  303. return server_error_response(e)
  304. @manager.route('/delete', methods=['DELETE'])
  305. @token_required
  306. def rm(tenant_id):
  307. req = request.args
  308. if "doc_id" not in req:
  309. return get_data_error_result(
  310. retmsg="doc_id is required")
  311. doc_ids = req["doc_id"]
  312. if isinstance(doc_ids, str): doc_ids = [doc_ids]
  313. root_folder = FileService.get_root_folder(tenant_id)
  314. pf_id = root_folder["id"]
  315. FileService.init_knowledgebase_docs(pf_id, tenant_id)
  316. errors = ""
  317. for doc_id in doc_ids:
  318. try:
  319. e, doc = DocumentService.get_by_id(doc_id)
  320. if not e:
  321. return get_data_error_result(retmsg="Document not found!")
  322. tenant_id = DocumentService.get_tenant_id(doc_id)
  323. if not tenant_id:
  324. return get_data_error_result(retmsg="Tenant not found!")
  325. b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
  326. if not DocumentService.remove_document(doc, tenant_id):
  327. return get_data_error_result(
  328. retmsg="Database error (Document removal)!")
  329. f2d = File2DocumentService.get_by_document_id(doc_id)
  330. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
  331. File2DocumentService.delete_by_document_id(doc_id)
  332. STORAGE_IMPL.rm(b, n)
  333. except Exception as e:
  334. errors += str(e)
  335. if errors:
  336. return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
  337. return get_json_result(data=True, retmsg="success")
  338. @manager.route("/<document_id>/status", methods=["GET"])
  339. @token_required
  340. def show_parsing_status(tenant_id, document_id):
  341. try:
  342. # valid document
  343. exist, _ = DocumentService.get_by_id(document_id)
  344. if not exist:
  345. return construct_json_result(code=RetCode.DATA_ERROR,
  346. message=f"This document: '{document_id}' is not a valid document.")
  347. _, doc = DocumentService.get_by_id(document_id) # get doc object
  348. doc_attributes = doc.to_dict()
  349. return construct_json_result(
  350. data={"progress": doc_attributes["progress"], "status": TaskStatus(doc_attributes["status"]).name},
  351. code=RetCode.SUCCESS
  352. )
  353. except Exception as e:
  354. return construct_error_response(e)
  355. @manager.route('/run', methods=['POST'])
  356. @token_required
  357. def run(tenant_id):
  358. req = request.json
  359. try:
  360. for id in req["doc_ids"]:
  361. info = {"run": str(req["run"]), "progress": 0}
  362. if str(req["run"]) == TaskStatus.RUNNING.value:
  363. info["progress_msg"] = ""
  364. info["chunk_num"] = 0
  365. info["token_num"] = 0
  366. DocumentService.update_by_id(id, info)
  367. # if str(req["run"]) == TaskStatus.CANCEL.value:
  368. tenant_id = DocumentService.get_tenant_id(id)
  369. if not tenant_id:
  370. return get_data_error_result(retmsg="Tenant not found!")
  371. ELASTICSEARCH.deleteByQuery(
  372. Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
  373. if str(req["run"]) == TaskStatus.RUNNING.value:
  374. TaskService.filter_delete([Task.doc_id == id])
  375. e, doc = DocumentService.get_by_id(id)
  376. doc = doc.to_dict()
  377. doc["tenant_id"] = tenant_id
  378. bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
  379. queue_tasks(doc, bucket, name)
  380. return get_json_result(data=True)
  381. except Exception as e:
  382. return server_error_response(e)
  383. @manager.route('/chunk/list', methods=['POST'])
  384. @token_required
  385. @validate_request("doc_id")
  386. def list_chunk(tenant_id):
  387. req = request.json
  388. doc_id = req["doc_id"]
  389. page = int(req.get("page", 1))
  390. size = int(req.get("size", 30))
  391. question = req.get("keywords", "")
  392. try:
  393. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  394. if not tenant_id:
  395. return get_data_error_result(retmsg="Tenant not found!")
  396. e, doc = DocumentService.get_by_id(doc_id)
  397. if not e:
  398. return get_data_error_result(retmsg="Document not found!")
  399. query = {
  400. "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
  401. }
  402. if "available_int" in req:
  403. query["available_int"] = int(req["available_int"])
  404. sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
  405. res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
  406. origin_chunks=[]
  407. for id in sres.ids:
  408. d = {
  409. "chunk_id": id,
  410. "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
  411. id].get(
  412. "content_with_weight", ""),
  413. "doc_id": sres.field[id]["doc_id"],
  414. "docnm_kwd": sres.field[id]["docnm_kwd"],
  415. "important_kwd": sres.field[id].get("important_kwd", []),
  416. "img_id": sres.field[id].get("img_id", ""),
  417. "available_int": sres.field[id].get("available_int", 1),
  418. "positions": sres.field[id].get("position_int", "").split("\t")
  419. }
  420. if len(d["positions"]) % 5 == 0:
  421. poss = []
  422. for i in range(0, len(d["positions"]), 5):
  423. poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
  424. float(d["positions"][i + 3]), float(d["positions"][i + 4])])
  425. d["positions"] = poss
  426. origin_chunks.append(d)
  427. ##rename keys
  428. for chunk in origin_chunks:
  429. key_mapping = {
  430. "chunk_id": "id",
  431. "content_with_weight": "content",
  432. "doc_id": "document_id",
  433. "important_kwd": "important_keywords",
  434. }
  435. renamed_chunk = {}
  436. for key, value in chunk.items():
  437. new_key = key_mapping.get(key, key)
  438. renamed_chunk[new_key] = value
  439. res["chunks"].append(renamed_chunk)
  440. return get_json_result(data=res)
  441. except Exception as e:
  442. if str(e).find("not_found") > 0:
  443. return get_json_result(data=False, retmsg=f'No chunk found!',
  444. retcode=RetCode.DATA_ERROR)
  445. return server_error_response(e)
  446. @manager.route('/chunk/create', methods=['POST'])
  447. @token_required
  448. @validate_request("doc_id", "content_with_weight")
  449. def create(tenant_id):
  450. req = request.json
  451. md5 = hashlib.md5()
  452. md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
  453. chunk_id = md5.hexdigest()
  454. d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
  455. "content_with_weight": req["content_with_weight"]}
  456. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  457. d["important_kwd"] = req.get("important_kwd", [])
  458. d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
  459. d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
  460. d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
  461. try:
  462. e, doc = DocumentService.get_by_id(req["doc_id"])
  463. if not e:
  464. return get_data_error_result(retmsg="Document not found!")
  465. d["kb_id"] = [doc.kb_id]
  466. d["docnm_kwd"] = doc.name
  467. d["doc_id"] = doc.id
  468. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  469. if not tenant_id:
  470. return get_data_error_result(retmsg="Tenant not found!")
  471. embd_id = DocumentService.get_embd_id(req["doc_id"])
  472. embd_mdl = TenantLLMService.model_instance(
  473. tenant_id, LLMType.EMBEDDING.value, embd_id)
  474. v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
  475. v = 0.1 * v[0] + 0.9 * v[1]
  476. d["q_%d_vec" % len(v)] = v.tolist()
  477. ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
  478. DocumentService.increment_chunk_num(
  479. doc.id, doc.kb_id, c, 1, 0)
  480. d["chunk_id"] = chunk_id
  481. #rename keys
  482. key_mapping = {
  483. "chunk_id": "id",
  484. "content_with_weight": "content",
  485. "doc_id": "document_id",
  486. "important_kwd": "important_keywords",
  487. "kb_id":"knowledge_base_id",
  488. }
  489. renamed_chunk = {}
  490. for key, value in d.items():
  491. new_key = key_mapping.get(key, key)
  492. renamed_chunk[new_key] = value
  493. return get_json_result(data={"chunk": renamed_chunk})
  494. # return get_json_result(data={"chunk_id": chunk_id})
  495. except Exception as e:
  496. return server_error_response(e)
  497. @manager.route('/chunk/rm', methods=['POST'])
  498. @token_required
  499. @validate_request("chunk_ids", "doc_id")
  500. def rm_chunk(tenant_id):
  501. req = request.json
  502. try:
  503. if not ELASTICSEARCH.deleteByQuery(
  504. Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
  505. return get_data_error_result(retmsg="Index updating failure")
  506. e, doc = DocumentService.get_by_id(req["doc_id"])
  507. if not e:
  508. return get_data_error_result(retmsg="Document not found!")
  509. deleted_chunk_ids = req["chunk_ids"]
  510. chunk_number = len(deleted_chunk_ids)
  511. DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
  512. return get_json_result(data=True)
  513. except Exception as e:
  514. return server_error_response(e)
  515. @manager.route('/chunk/set', methods=['POST'])
  516. @token_required
  517. @validate_request("doc_id", "chunk_id", "content_with_weight",
  518. "important_kwd")
  519. def set(tenant_id):
  520. req = request.json
  521. d = {
  522. "id": req["chunk_id"],
  523. "content_with_weight": req["content_with_weight"]}
  524. d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
  525. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  526. d["important_kwd"] = req["important_kwd"]
  527. d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
  528. if "available_int" in req:
  529. d["available_int"] = req["available_int"]
  530. try:
  531. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  532. if not tenant_id:
  533. return get_data_error_result(retmsg="Tenant not found!")
  534. embd_id = DocumentService.get_embd_id(req["doc_id"])
  535. embd_mdl = TenantLLMService.model_instance(
  536. tenant_id, LLMType.EMBEDDING.value, embd_id)
  537. e, doc = DocumentService.get_by_id(req["doc_id"])
  538. if not e:
  539. return get_data_error_result(retmsg="Document not found!")
  540. if doc.parser_id == ParserType.QA:
  541. arr = [
  542. t for t in re.split(
  543. r"[\n\t]",
  544. req["content_with_weight"]) if len(t) > 1]
  545. if len(arr) != 2:
  546. return get_data_error_result(
  547. retmsg="Q&A must be separated by TAB/ENTER key.")
  548. q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
  549. d = beAdoc(d, arr[0], arr[1], not any(
  550. [rag_tokenizer.is_chinese(t) for t in q + a]))
  551. v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
  552. v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
  553. d["q_%d_vec" % len(v)] = v.tolist()
  554. ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
  555. return get_json_result(data=True)
  556. except Exception as e:
  557. return server_error_response(e)
  558. @manager.route('/retrieval_test', methods=['POST'])
  559. @token_required
  560. @validate_request("kb_id", "question")
  561. def retrieval_test(tenant_id):
  562. req = request.json
  563. page = int(req.get("page", 1))
  564. size = int(req.get("size", 30))
  565. question = req["question"]
  566. kb_id = req["kb_id"]
  567. if isinstance(kb_id, str): kb_id = [kb_id]
  568. doc_ids = req.get("doc_ids", [])
  569. similarity_threshold = float(req.get("similarity_threshold", 0.2))
  570. vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
  571. top = int(req.get("top_k", 1024))
  572. try:
  573. tenants = UserTenantService.query(user_id=tenant_id)
  574. for kid in kb_id:
  575. for tenant in tenants:
  576. if KnowledgebaseService.query(
  577. tenant_id=tenant.tenant_id, id=kid):
  578. break
  579. else:
  580. return get_json_result(
  581. data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
  582. retcode=RetCode.OPERATING_ERROR)
  583. e, kb = KnowledgebaseService.get_by_id(kb_id[0])
  584. if not e:
  585. return get_data_error_result(retmsg="Knowledgebase not found!")
  586. embd_mdl = TenantLLMService.model_instance(
  587. kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
  588. rerank_mdl = None
  589. if req.get("rerank_id"):
  590. rerank_mdl = TenantLLMService.model_instance(
  591. kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
  592. if req.get("keyword", False):
  593. chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
  594. question += keyword_extraction(chat_mdl, question)
  595. retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
  596. ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
  597. similarity_threshold, vector_similarity_weight, top,
  598. doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
  599. for c in ranks["chunks"]:
  600. if "vector" in c:
  601. del c["vector"]
  602. ##rename keys
  603. renamed_chunks=[]
  604. for chunk in ranks["chunks"]:
  605. key_mapping = {
  606. "chunk_id": "id",
  607. "content_with_weight": "content",
  608. "doc_id": "document_id",
  609. "important_kwd": "important_keywords",
  610. }
  611. rename_chunk={}
  612. for key, value in chunk.items():
  613. new_key = key_mapping.get(key, key)
  614. rename_chunk[new_key] = value
  615. renamed_chunks.append(rename_chunk)
  616. ranks["chunks"] = renamed_chunks
  617. return get_json_result(data=ranks)
  618. except Exception as e:
  619. if str(e).find("not_found") > 0:
  620. return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
  621. retcode=RetCode.DATA_ERROR)
  622. return server_error_response(e)