doc.py

import datetime
import hashlib
import pathlib
import re
from io import BytesIO

from elasticsearch_dsl import Q
from flask import request, send_file

from api.db import FileSource, FileType, LLMType, ParserType, TaskStatus
from api.db.db_models import File, Task
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.settings import RetCode, kg_retrievaler, retrievaler
from api.utils.api_utils import (
    construct_json_result,
    get_error_data_result,
    get_result,
    server_error_response,
    token_required,
)
from rag.app.qa import beAdoc, rmPrefix
from rag.nlp import keyword_extraction, rag_tokenizer, search
from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.storage_factory import STORAGE_IMPL

MAXIMUM_OF_UPLOADING_FILES = 256

# NOTE: `manager` (the Flask blueprint these routes attach to) is not imported
# here; the surrounding API app loader injects it into this module at load time.
@manager.route('/dataset/<dataset_id>/document', methods=['POST'])
@token_required
def upload(dataset_id, tenant_id):
    if 'file' not in request.files:
        return get_error_data_result(
            retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
    file_objs = request.files.getlist('file')
    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_result(
                retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
    if err:
        return get_result(
            retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
    return get_result()
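
# Example usage (a sketch, not part of this module): BASE_URL, API_KEY and the
# blueprint's mount path are placeholders, and bearer-token auth for
# @token_required is an assumption.
#
#   import requests
#   requests.post(
#       f"{BASE_URL}/dataset/{dataset_id}/document",
#       headers={"Authorization": f"Bearer {API_KEY}"},
#       files={"file": open("report.pdf", "rb")},
#   )
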
@manager.route('/dataset/<dataset_id>/info/<document_id>', methods=['PUT'])
@token_required
def update_doc(tenant_id, dataset_id, document_id):
    req = request.json
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg="You don't own the dataset.")
    doc = DocumentService.query(kb_id=dataset_id, id=document_id)
    if not doc:
        return get_error_data_result(retmsg="The dataset doesn't own the document.")
    doc = doc[0]
    # `chunk_count`, `token_count` and `progress` are read-only.
    if "chunk_count" in req and req["chunk_count"] != doc.chunk_num:
        return get_error_data_result(retmsg="Can't change `chunk_count`.")
    if "token_count" in req and req["token_count"] != doc.token_num:
        return get_error_data_result(retmsg="Can't change `token_count`.")
    if "progress" in req and req["progress"] != doc.progress:
        return get_error_data_result(retmsg="Can't change `progress`.")
    if "name" in req and req["name"] != doc.name:
        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
            return get_result(
                retmsg="The extension of file can't be changed",
                retcode=RetCode.ARGUMENT_ERROR)
        for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
            if d.name == req["name"]:
                return get_error_data_result(
                    retmsg="Duplicated document name in the same knowledgebase.")
        if not DocumentService.update_by_id(document_id, {"name": req["name"]}):
            return get_error_data_result(retmsg="Database error (Document rename)!")
        informs = File2DocumentService.get_by_document_id(document_id)
        if informs:
            e, file = FileService.get_by_id(informs[0].file_id)
            FileService.update_by_id(file.id, {"name": req["name"]})
    if "parser_config" in req:
        DocumentService.update_parser_config(doc.id, req["parser_config"])
    if "chunk_method" in req:
        if doc.parser_id.lower() == req["chunk_method"].lower():
            return get_result()
        if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
            return get_error_data_result(retmsg="Not supported yet!")
        e = DocumentService.update_by_id(
            doc.id,
            {"parser_id": req["chunk_method"], "progress": 0, "progress_msg": "",
             "run": TaskStatus.UNSTART.value})
        if not e:
            return get_error_data_result(retmsg="Document not found!")
        if doc.token_num > 0:
            # `process_duation` matches the column name in the DB model.
            e = DocumentService.increment_chunk_num(
                doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                doc.process_duation * -1)
            if not e:
                return get_error_data_result(retmsg="Document not found!")
            tenant_id = DocumentService.get_tenant_id(document_id)
            if not tenant_id:
                return get_error_data_result(retmsg="Tenant not found!")
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
    return get_result()
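
# Example request body (sketch; hypothetical values, same BASE_URL/API_KEY
# assumptions as the upload example):
#
#   PUT {BASE_URL}/dataset/{dataset_id}/info/{document_id}
#   {"name": "renamed.pdf", "chunk_method": "naive",
#    "parser_config": {"chunk_token_num": 128}}
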
@manager.route('/dataset/<dataset_id>/document/<document_id>', methods=['GET'])
@token_required
def download(tenant_id, dataset_id, document_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f'You do not own the dataset {dataset_id}.')
    doc = DocumentService.query(kb_id=dataset_id, id=document_id)
    if not doc:
        return get_error_data_result(retmsg=f'The dataset does not own the document {document_id}.')
    # Resolve the storage (MinIO) bucket and object name, then fetch the bytes.
    doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)
    file_stream = STORAGE_IMPL.get(doc_id, doc_location)
    if not file_stream:
        return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
    file = BytesIO(file_stream)
    # Send with the stored filename and a generic default MIME type.
    return send_file(
        file,
        as_attachment=True,
        download_name=doc[0].name,
        mimetype='application/octet-stream'
    )
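
# Example: saving the attachment to disk (sketch, same assumptions as above):
#
#   resp = requests.get(f"{BASE_URL}/dataset/{dataset_id}/document/{document_id}",
#                       headers={"Authorization": f"Bearer {API_KEY}"})
#   with open("downloaded.bin", "wb") as f:
#       f.write(resp.content)
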
@manager.route('/dataset/<dataset_id>/info', methods=['GET'])
@token_required
def list_docs(dataset_id, tenant_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    id = request.args.get("id")
    if id and not DocumentService.query(id=id, kb_id=dataset_id):
        return get_error_data_result(retmsg=f"You don't own the document {id}.")
    offset = int(request.args.get("offset", 1))
    keywords = request.args.get("keywords", "")
    limit = int(request.args.get("limit", 1024))
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc") != "False"
    docs, total = DocumentService.get_list(dataset_id, offset, limit, orderby, desc, keywords, id)
    # Rename internal field names to the public API names.
    key_mapping = {
        "chunk_num": "chunk_count",
        "kb_id": "knowledgebase_id",
        "token_num": "token_count",
        "parser_id": "chunk_method"
    }
    renamed_doc_list = []
    for doc in docs:
        renamed_doc = {key_mapping.get(key, key): value for key, value in doc.items()}
        renamed_doc_list.append(renamed_doc)
    return get_result(data={"total": total, "docs": renamed_doc_list})
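
# Example: paging through a dataset's documents (sketch; hypothetical values):
#
#   GET {BASE_URL}/dataset/{dataset_id}/info?offset=1&limit=30&orderby=create_time&desc=True
#
# The response carries {"total": ..., "docs": [...]} with the renamed keys
# (`chunk_count`, `token_count`, `chunk_method`, `knowledgebase_id`).
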
@manager.route('/dataset/<dataset_id>/document', methods=['DELETE'])
@token_required
def delete(tenant_id, dataset_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    req = request.json
    if not req.get("ids"):
        return get_error_data_result(retmsg="`ids` is required")
    doc_ids = req["ids"]
    root_folder = FileService.get_root_folder(tenant_id)
    pf_id = root_folder["id"]
    FileService.init_knowledgebase_docs(pf_id, tenant_id)
    errors = ""
    for doc_id in doc_ids:
        try:
            e, doc = DocumentService.get_by_id(doc_id)
            if not e:
                return get_error_data_result(retmsg="Document not found!")
            tenant_id = DocumentService.get_tenant_id(doc_id)
            if not tenant_id:
                return get_error_data_result(retmsg="Tenant not found!")
            b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
            if not DocumentService.remove_document(doc, tenant_id):
                return get_error_data_result(
                    retmsg="Database error (Document removal)!")
            f2d = File2DocumentService.get_by_document_id(doc_id)
            FileService.filter_delete(
                [File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
            File2DocumentService.delete_by_document_id(doc_id)
            STORAGE_IMPL.rm(b, n)
        except Exception as e:
            # Collect per-document failures instead of aborting the batch.
            errors += str(e)
    if errors:
        return get_result(retmsg=errors, retcode=RetCode.SERVER_ERROR)
    return get_result()
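
# Example request body (sketch):
#
#   DELETE {BASE_URL}/dataset/{dataset_id}/document
#   {"ids": ["<document_id_1>", "<document_id_2>"]}
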
@manager.route('/dataset/<dataset_id>/chunk', methods=['POST'])
@token_required
def parse(tenant_id, dataset_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    req = request.json
    if not req.get("document_ids"):
        return get_error_data_result("`document_ids` is required")
    for id in req["document_ids"]:
        if not DocumentService.query(id=id, kb_id=dataset_id):
            return get_error_data_result(retmsg=f"You don't own the document {id}.")
        # Reset progress counters, drop stale chunks and tasks,
        # then queue fresh parsing tasks.
        info = {"run": "1", "progress": 0, "progress_msg": "",
                "chunk_num": 0, "token_num": 0}
        DocumentService.update_by_id(id, info)
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
        TaskService.filter_delete([Task.doc_id == id])
        e, doc = DocumentService.get_by_id(id)
        doc = doc.to_dict()
        doc["tenant_id"] = tenant_id
        bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
        queue_tasks(doc, bucket, name)
    return get_result()
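
# Example: queue parsing for two documents (sketch):
#
#   POST {BASE_URL}/dataset/{dataset_id}/chunk
#   {"document_ids": ["<doc_id_1>", "<doc_id_2>"]}
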
@manager.route('/dataset/<dataset_id>/chunk', methods=['DELETE'])
@token_required
def stop_parsing(tenant_id, dataset_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    req = request.json
    if not req.get("document_ids"):
        return get_error_data_result("`document_ids` is required")
    for id in req["document_ids"]:
        doc = DocumentService.query(id=id, kb_id=dataset_id)
        if not doc:
            return get_error_data_result(retmsg=f"You don't own the document {id}.")
        if doc[0].progress == 100.0 or doc[0].progress == 0.0:
            return get_error_data_result(
                "Can't stop parsing a document whose progress is at 0 or 100.")
        # Mark the run as cancelled, reset progress and drop the stale chunks.
        info = {"run": "2", "progress": 0}
        DocumentService.update_by_id(id, info)
        tenant_id = DocumentService.get_tenant_id(id)
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
    return get_result()
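
# Example: cancel parsing (sketch); only meaningful while progress is strictly
# between 0 and 100:
#
#   DELETE {BASE_URL}/dataset/{dataset_id}/chunk
#   {"document_ids": ["<doc_id>"]}
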
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['GET'])
@token_required
def list_chunks(tenant_id, dataset_id, document_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    doc = DocumentService.query(id=document_id, kb_id=dataset_id)
    if not doc:
        return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
    doc = doc[0]
    req = request.args
    doc_id = document_id
    page = int(req.get("offset", 1))
    size = int(req.get("limit", 30))
    question = req.get("keywords", "")
    query = {
        "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
    }
    sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
    res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
    origin_chunks = []
    sign = 0
    for id in sres.ids:
        d = {
            "chunk_id": id,
            "content_with_weight": rmSpace(sres.highlight[id])
            if question and id in sres.highlight
            else sres.field[id].get("content_with_weight", ""),
            "doc_id": sres.field[id]["doc_id"],
            "docnm_kwd": sres.field[id]["docnm_kwd"],
            "important_kwd": sres.field[id].get("important_kwd", []),
            "img_id": sres.field[id].get("img_id", ""),
            "available_int": sres.field[id].get("available_int", 1),
            "positions": sres.field[id].get("position_int", "").split("\t")
        }
        # Positions are stored as tab-separated values in groups of five;
        # regroup them into 5-element lists of floats.
        if len(d["positions"]) % 5 == 0:
            poss = []
            for i in range(0, len(d["positions"]), 5):
                poss.append([float(d["positions"][i]), float(d["positions"][i + 1]),
                             float(d["positions"][i + 2]), float(d["positions"][i + 3]),
                             float(d["positions"][i + 4])])
            d["positions"] = poss
        origin_chunks.append(d)
        if req.get("id") == id:
            # A specific chunk was requested; keep only that one.
            origin_chunks.clear()
            origin_chunks.append(d)
            sign = 1
            break
    if req.get("id") and sign == 0:
        return get_error_data_result(f"Can't find this chunk {req.get('id')}")
    # Rename internal field names to the public API names.
    key_mapping = {
        "chunk_id": "id",
        "content_with_weight": "content",
        "doc_id": "document_id",
        "important_kwd": "important_keywords",
        "img_id": "image_id",
    }
    for chunk in origin_chunks:
        renamed_chunk = {key_mapping.get(key, key): value for key, value in chunk.items()}
        res["chunks"].append(renamed_chunk)
    return get_result(data=res)
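
# Example: fetch the first 30 chunks of a document, or a single chunk by id (sketch):
#
#   GET {BASE_URL}/dataset/{dataset_id}/document/{document_id}/chunk?offset=1&limit=30
#   GET {BASE_URL}/dataset/{dataset_id}/document/{document_id}/chunk?id=<chunk_id>
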
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
@token_required
def create(tenant_id, dataset_id, document_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    doc = DocumentService.query(id=document_id, kb_id=dataset_id)
    if not doc:
        return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
    doc = doc[0]
    req = request.json
    if not req.get("content"):
        return get_error_data_result(retmsg="`content` is required")
    if "important_keywords" in req and not isinstance(req["important_keywords"], list):
        return get_error_data_result("`important_keywords` is required to be a list")
    # The chunk id is the MD5 of content + document id, so re-posting the same
    # content into the same document upserts rather than duplicates.
    md5 = hashlib.md5()
    md5.update((req["content"] + document_id).encode("utf-8"))
    chunk_id = md5.hexdigest()
    d = {"id": chunk_id,
         "content_ltks": rag_tokenizer.tokenize(req["content"]),
         "content_with_weight": req["content"]}
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = req.get("important_keywords", [])
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", [])))
    d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
    d["kb_id"] = [doc.kb_id]
    d["docnm_kwd"] = doc.name
    d["doc_id"] = doc.id
    embd_id = DocumentService.get_embd_id(document_id)
    embd_mdl = TenantLLMService.model_instance(
        tenant_id, LLMType.EMBEDDING.value, embd_id)
    # Embed both the document name and the content, weighting the content 9:1.
    v, c = embd_mdl.encode([doc.name, req["content"]])
    v = 0.1 * v[0] + 0.9 * v[1]
    d["q_%d_vec" % len(v)] = v.tolist()
    ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
    DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
    d["chunk_id"] = chunk_id
    # Rename internal field names to the public API names.
    key_mapping = {
        "chunk_id": "id",
        "content_with_weight": "content",
        "doc_id": "document_id",
        "important_kwd": "important_keywords",
        "kb_id": "dataset_id",
        "create_timestamp_flt": "create_timestamp",
        "create_time": "create_time",
        "document_keyword": "document",
    }
    renamed_chunk = {key_mapping[key]: value for key, value in d.items() if key in key_mapping}
    return get_result(data={"chunk": renamed_chunk})
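
# Example request body (sketch; hypothetical content). The returned chunk id is
# the MD5 digest computed above:
#
#   POST {BASE_URL}/dataset/{dataset_id}/document/{document_id}/chunk
#   {"content": "RAGFlow is a RAG engine.", "important_keywords": ["RAGFlow"]}
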
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['DELETE'])
@token_required
def rm_chunk(tenant_id, dataset_id, document_id):
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    doc = DocumentService.query(id=document_id, kb_id=dataset_id)
    if not doc:
        return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
    doc = doc[0]
    req = request.json
    if not req.get("chunk_ids"):
        return get_error_data_result("`chunk_ids` is required")
    # Verify every requested chunk actually belongs to this document.
    query = {
        "doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
    sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
    for chunk_id in req.get("chunk_ids"):
        if chunk_id not in sres.ids:
            return get_error_data_result(f"Chunk {chunk_id} not found")
    if not ELASTICSEARCH.deleteByQuery(
            Q("ids", values=req["chunk_ids"]), idxnm=search.index_name(tenant_id)):
        return get_error_data_result(retmsg="Index updating failure")
    deleted_chunk_ids = req["chunk_ids"]
    chunk_number = len(deleted_chunk_ids)
    DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
    return get_result()
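
# Example request body (sketch):
#
#   DELETE {BASE_URL}/dataset/{dataset_id}/document/{document_id}/chunk
#   {"chunk_ids": ["<chunk_id_1>", "<chunk_id_2>"]}
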
@manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
@token_required
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
    try:
        res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id))
    except Exception:
        return get_error_data_result(f"Can't find this chunk {chunk_id}")
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
    doc = DocumentService.query(id=document_id, kb_id=dataset_id)
    if not doc:
        return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
    doc = doc[0]
    query = {
        "doc_ids": [document_id], "page": 1, "size": 1024, "question": "", "sort": True
    }
    sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
    if chunk_id not in sres.ids:
        return get_error_data_result(f"You don't own the chunk {chunk_id}")
    req = request.json
    # Fall back to the stored content when the request omits `content`.
    content = res["_source"].get("content_with_weight")
    d = {
        "id": chunk_id,
        "content_with_weight": req.get("content", content)}
    d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    if "important_keywords" in req:
        if not isinstance(req["important_keywords"], list):
            return get_error_data_result("`important_keywords` is required to be a list")
        d["important_kwd"] = req.get("important_keywords")
        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
    if "available" in req:
        d["available_int"] = req["available"]
    embd_id = DocumentService.get_embd_id(document_id)
    embd_mdl = TenantLLMService.model_instance(
        tenant_id, LLMType.EMBEDDING.value, embd_id)
    if doc.parser_id == ParserType.QA:
        # Q&A chunks must hold exactly a question and an answer,
        # separated by a tab or newline.
        arr = [
            t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1]
        if len(arr) != 2:
            return get_error_data_result(
                retmsg="Q&A must be separated by TAB/ENTER key.")
        q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
        d = beAdoc(d, arr[0], arr[1], not any(
            [rag_tokenizer.is_chinese(t) for t in q + a]))
    v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
    v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
    d["q_%d_vec" % len(v)] = v.tolist()
    ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
    return get_result()
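
# Example request body (sketch); any field may be omitted to keep the stored value:
#
#   PUT {BASE_URL}/dataset/{dataset_id}/document/{document_id}/chunk/{chunk_id}
#   {"content": "Updated text.", "important_keywords": ["updated"], "available": 1}
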
@manager.route('/retrieval', methods=['POST'])
@token_required
def retrieval_test(tenant_id):
    req = request.json
    if not req.get("datasets"):
        return get_error_data_result("`datasets` is required.")
    kb_ids = req["datasets"]
    if isinstance(kb_ids, str):
        kb_ids = [kb_ids]
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))
    if len(embd_nms) != 1:
        return get_result(
            retmsg='Datasets use different embedding models or do not exist.',
            retcode=RetCode.AUTHENTICATION_ERROR)
    for id in kb_ids:
        if not KnowledgebaseService.query(id=id, tenant_id=tenant_id):
            return get_error_data_result(f"You don't own the dataset {id}.")
    if "question" not in req:
        return get_error_data_result("`question` is required.")
    page = int(req.get("offset", 1))
    size = int(req.get("limit", 30))
    question = req["question"]
    doc_ids = req.get("documents", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    highlight = req.get("highlight") not in ("False", "false")
    try:
        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
            return get_error_data_result(retmsg="Knowledgebase not found!")
        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = TenantLLMService.model_instance(
                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
        if req.get("keyword", False):
            # Optionally expand the query with LLM-extracted keywords.
            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                               similarity_threshold, vector_similarity_weight, top,
                               doc_ids, rerank_mdl=rerank_mdl, highlight=highlight)
        for c in ranks["chunks"]:
            if "vector" in c:
                del c["vector"]
        # Rename internal field names to the public API names.
        key_mapping = {
            "chunk_id": "id",
            "content_with_weight": "content",
            "doc_id": "document_id",
            "important_kwd": "important_keywords",
            "docnm_kwd": "document_keyword"
        }
        renamed_chunks = []
        for chunk in ranks["chunks"]:
            rename_chunk = {key_mapping.get(key, key): value for key, value in chunk.items()}
            renamed_chunks.append(rename_chunk)
        ranks["chunks"] = renamed_chunks
        return get_result(data=ranks)
    except Exception as e:
        if "not_found" in str(e):
            return get_result(retmsg='No chunk found! Check the chunk status please!',
                              retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
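
# Example request body (sketch; hypothetical ids, defaults shown explicitly):
#
#   POST {BASE_URL}/retrieval
#   {"datasets": ["<dataset_id>"], "question": "What is RAGFlow?",
#    "offset": 1, "limit": 30, "similarity_threshold": 0.2,
#    "vector_similarity_weight": 0.3, "top_k": 1024, "keyword": false}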