您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. import pathlib
  2. import re
  3. import datetime
  4. import json
  5. import traceback
  6. from flask import request
  7. from flask_login import login_required, current_user
  8. from elasticsearch_dsl import Q
  9. from rag.app.qa import rmPrefix, beAdoc
  10. from rag.nlp import search, rag_tokenizer, keyword_extraction
  11. from rag.utils.es_conn import ELASTICSEARCH
  12. from rag.utils import rmSpace
  13. from api.db import LLMType, ParserType
  14. from api.db.services.knowledgebase_service import KnowledgebaseService
  15. from api.db.services.llm_service import TenantLLMService
  16. from api.db.services.user_service import UserTenantService
  17. from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
  18. from api.db.services.document_service import DocumentService
  19. from api.settings import RetCode, retrievaler, kg_retrievaler
  20. from api.utils.api_utils import get_json_result
  21. import hashlib
  22. import re
  23. from api.utils.api_utils import get_json_result, token_required, get_data_error_result
  24. from api.db.db_models import Task, File
  25. from api.db.services.task_service import TaskService, queue_tasks
  26. from api.db.services.user_service import TenantService, UserTenantService
  27. from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
  28. from api.utils.api_utils import get_json_result
  29. from functools import partial
  30. from io import BytesIO
  31. from elasticsearch_dsl import Q
  32. from flask import request, send_file
  33. from flask_login import login_required
  34. from api.db import FileSource, TaskStatus, FileType
  35. from api.db.db_models import File
  36. from api.db.services.document_service import DocumentService
  37. from api.db.services.file2document_service import File2DocumentService
  38. from api.db.services.file_service import FileService
  39. from api.db.services.knowledgebase_service import KnowledgebaseService
  40. from api.settings import RetCode, retrievaler
  41. from api.utils.api_utils import construct_json_result, construct_error_response
  42. from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio, email
  43. from rag.nlp import search
  44. from rag.utils import rmSpace
  45. from rag.utils.es_conn import ELASTICSEARCH
  46. from rag.utils.storage_factory import STORAGE_IMPL
  47. MAXIMUM_OF_UPLOADING_FILES = 256
  48. MAXIMUM_OF_UPLOADING_FILES = 256
  49. @manager.route('/dataset/<dataset_id>/documents/upload', methods=['POST'])
  50. @token_required
  51. def upload(dataset_id, tenant_id):
  52. if 'file' not in request.files:
  53. return get_json_result(
  54. data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
  55. file_objs = request.files.getlist('file')
  56. for file_obj in file_objs:
  57. if file_obj.filename == '':
  58. return get_json_result(
  59. data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
  60. e, kb = KnowledgebaseService.get_by_id(dataset_id)
  61. if not e:
  62. raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
  63. err, _ = FileService.upload_document(kb, file_objs, tenant_id)
  64. if err:
  65. return get_json_result(
  66. data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
  67. return get_json_result(data=True)
  68. @manager.route('/infos', methods=['GET'])
  69. @token_required
  70. def docinfos(tenant_id):
  71. req = request.args
  72. if "id" in req:
  73. doc_id = req["id"]
  74. e, doc = DocumentService.get_by_id(doc_id)
  75. return get_json_result(data=doc.to_json())
  76. if "name" in req:
  77. doc_name = req["name"]
  78. doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
  79. e, doc = DocumentService.get_by_id(doc_id)
  80. return get_json_result(data=doc.to_json())
  81. @manager.route('/save', methods=['POST'])
  82. @token_required
  83. def save_doc(tenant_id):
  84. req = request.json
  85. #get doc by id or name
  86. doc_id = None
  87. if "id" in req:
  88. doc_id = req["id"]
  89. elif "name" in req:
  90. doc_name = req["name"]
  91. doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
  92. if not doc_id:
  93. return get_json_result(retcode=400, retmsg="Document ID or name is required")
  94. e, doc = DocumentService.get_by_id(doc_id)
  95. if not e:
  96. return get_data_error_result(retmsg="Document not found!")
  97. #other value can't be changed
  98. if "chunk_num" in req:
  99. if req["chunk_num"] != doc.chunk_num:
  100. return get_data_error_result(
  101. retmsg="Can't change chunk_count.")
  102. if "progress" in req:
  103. if req['progress'] != doc.progress:
  104. return get_data_error_result(
  105. retmsg="Can't change progress.")
  106. #change name or parse_method
  107. if "name" in req and req["name"] != doc.name:
  108. try:
  109. if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
  110. doc.name.lower()).suffix:
  111. return get_json_result(
  112. data=False,
  113. retmsg="The extension of file can't be changed",
  114. retcode=RetCode.ARGUMENT_ERROR)
  115. for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
  116. if d.name == req["name"]:
  117. return get_data_error_result(
  118. retmsg="Duplicated document name in the same knowledgebase.")
  119. if not DocumentService.update_by_id(
  120. doc_id, {"name": req["name"]}):
  121. return get_data_error_result(
  122. retmsg="Database error (Document rename)!")
  123. informs = File2DocumentService.get_by_document_id(doc_id)
  124. if informs:
  125. e, file = FileService.get_by_id(informs[0].file_id)
  126. FileService.update_by_id(file.id, {"name": req["name"]})
  127. except Exception as e:
  128. return server_error_response(e)
  129. if "parser_id" in req:
  130. try:
  131. if doc.parser_id.lower() == req["parser_id"].lower():
  132. if "parser_config" in req:
  133. if req["parser_config"] == doc.parser_config:
  134. return get_json_result(data=True)
  135. else:
  136. return get_json_result(data=True)
  137. if doc.type == FileType.VISUAL or re.search(
  138. r"\.(ppt|pptx|pages)$", doc.name):
  139. return get_data_error_result(retmsg="Not supported yet!")
  140. e = DocumentService.update_by_id(doc.id,
  141. {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
  142. "run": TaskStatus.UNSTART.value})
  143. if not e:
  144. return get_data_error_result(retmsg="Document not found!")
  145. if "parser_config" in req:
  146. DocumentService.update_parser_config(doc.id, req["parser_config"])
  147. if doc.token_num > 0:
  148. e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
  149. doc.process_duation * -1)
  150. if not e:
  151. return get_data_error_result(retmsg="Document not found!")
  152. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  153. if not tenant_id:
  154. return get_data_error_result(retmsg="Tenant not found!")
  155. ELASTICSEARCH.deleteByQuery(
  156. Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
  157. except Exception as e:
  158. return server_error_response(e)
  159. return get_json_result(data=True)
  160. @manager.route('/change_parser', methods=['POST'])
  161. @token_required
  162. def change_parser(tenant_id):
  163. req = request.json
  164. try:
  165. e, doc = DocumentService.get_by_id(req["doc_id"])
  166. if not e:
  167. return get_data_error_result(retmsg="Document not found!")
  168. if doc.parser_id.lower() == req["parser_id"].lower():
  169. if "parser_config" in req:
  170. if req["parser_config"] == doc.parser_config:
  171. return get_json_result(data=True)
  172. else:
  173. return get_json_result(data=True)
  174. if doc.type == FileType.VISUAL or re.search(
  175. r"\.(ppt|pptx|pages)$", doc.name):
  176. return get_data_error_result(retmsg="Not supported yet!")
  177. e = DocumentService.update_by_id(doc.id,
  178. {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
  179. "run": TaskStatus.UNSTART.value})
  180. if not e:
  181. return get_data_error_result(retmsg="Document not found!")
  182. if "parser_config" in req:
  183. DocumentService.update_parser_config(doc.id, req["parser_config"])
  184. if doc.token_num > 0:
  185. e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
  186. doc.process_duation * -1)
  187. if not e:
  188. return get_data_error_result(retmsg="Document not found!")
  189. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  190. if not tenant_id:
  191. return get_data_error_result(retmsg="Tenant not found!")
  192. ELASTICSEARCH.deleteByQuery(
  193. Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
  194. return get_json_result(data=True)
  195. except Exception as e:
  196. return server_error_response(e)
  197. @manager.route('/rename', methods=['POST'])
  198. @login_required
  199. @validate_request("doc_id", "name")
  200. def rename():
  201. req = request.json
  202. try:
  203. e, doc = DocumentService.get_by_id(req["doc_id"])
  204. if not e:
  205. return get_data_error_result(retmsg="Document not found!")
  206. if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
  207. doc.name.lower()).suffix:
  208. return get_json_result(
  209. data=False,
  210. retmsg="The extension of file can't be changed",
  211. retcode=RetCode.ARGUMENT_ERROR)
  212. for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
  213. if d.name == req["name"]:
  214. return get_data_error_result(
  215. retmsg="Duplicated document name in the same knowledgebase.")
  216. if not DocumentService.update_by_id(
  217. req["doc_id"], {"name": req["name"]}):
  218. return get_data_error_result(
  219. retmsg="Database error (Document rename)!")
  220. informs = File2DocumentService.get_by_document_id(req["doc_id"])
  221. if informs:
  222. e, file = FileService.get_by_id(informs[0].file_id)
  223. FileService.update_by_id(file.id, {"name": req["name"]})
  224. return get_json_result(data=True)
  225. except Exception as e:
  226. return server_error_response(e)
  227. @manager.route("/<document_id>", methods=["GET"])
  228. @token_required
  229. def download_document(dataset_id, document_id):
  230. try:
  231. # Check whether there is this document
  232. exist, document = DocumentService.get_by_id(document_id)
  233. if not exist:
  234. return construct_json_result(message=f"This document '{document_id}' cannot be found!",
  235. code=RetCode.ARGUMENT_ERROR)
  236. # The process of downloading
  237. doc_id, doc_location = File2DocumentService.get_minio_address(doc_id=document_id) # minio address
  238. file_stream = STORAGE_IMPL.get(doc_id, doc_location)
  239. if not file_stream:
  240. return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
  241. file = BytesIO(file_stream)
  242. # Use send_file with a proper filename and MIME type
  243. return send_file(
  244. file,
  245. as_attachment=True,
  246. download_name=document.name,
  247. mimetype='application/octet-stream' # Set a default MIME type
  248. )
  249. # Error
  250. except Exception as e:
  251. return construct_error_response(e)
  252. @manager.route('/dataset/<dataset_id>/documents', methods=['GET'])
  253. @token_required
  254. def list_docs(dataset_id, tenant_id):
  255. kb_id = request.args.get("kb_id")
  256. if not kb_id:
  257. return get_json_result(
  258. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  259. tenants = UserTenantService.query(user_id=tenant_id)
  260. for tenant in tenants:
  261. if KnowledgebaseService.query(
  262. tenant_id=tenant.tenant_id, id=kb_id):
  263. break
  264. else:
  265. return get_json_result(
  266. data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
  267. retcode=RetCode.OPERATING_ERROR)
  268. keywords = request.args.get("keywords", "")
  269. page_number = int(request.args.get("page", 1))
  270. items_per_page = int(request.args.get("page_size", 15))
  271. orderby = request.args.get("orderby", "create_time")
  272. desc = request.args.get("desc", True)
  273. try:
  274. docs, tol = DocumentService.get_by_kb_id(
  275. kb_id, page_number, items_per_page, orderby, desc, keywords)
  276. return get_json_result(data={"total": tol, "docs": docs})
  277. except Exception as e:
  278. return server_error_response(e)
  279. @manager.route('/delete', methods=['DELETE'])
  280. @token_required
  281. def rm(tenant_id):
  282. req = request.args
  283. if "doc_id" not in req:
  284. return get_data_error_result(
  285. retmsg="doc_id is required")
  286. doc_ids = req["doc_id"]
  287. if isinstance(doc_ids, str): doc_ids = [doc_ids]
  288. root_folder = FileService.get_root_folder(tenant_id)
  289. pf_id = root_folder["id"]
  290. FileService.init_knowledgebase_docs(pf_id, tenant_id)
  291. errors = ""
  292. for doc_id in doc_ids:
  293. try:
  294. e, doc = DocumentService.get_by_id(doc_id)
  295. if not e:
  296. return get_data_error_result(retmsg="Document not found!")
  297. tenant_id = DocumentService.get_tenant_id(doc_id)
  298. if not tenant_id:
  299. return get_data_error_result(retmsg="Tenant not found!")
  300. b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
  301. if not DocumentService.remove_document(doc, tenant_id):
  302. return get_data_error_result(
  303. retmsg="Database error (Document removal)!")
  304. f2d = File2DocumentService.get_by_document_id(doc_id)
  305. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
  306. File2DocumentService.delete_by_document_id(doc_id)
  307. STORAGE_IMPL.rm(b, n)
  308. except Exception as e:
  309. errors += str(e)
  310. if errors:
  311. return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
  312. return get_json_result(data=True, retmsg="success")
  313. @manager.route("/<document_id>/status", methods=["GET"])
  314. @token_required
  315. def show_parsing_status(tenant_id, document_id):
  316. try:
  317. # valid document
  318. exist, _ = DocumentService.get_by_id(document_id)
  319. if not exist:
  320. return construct_json_result(code=RetCode.DATA_ERROR,
  321. message=f"This document: '{document_id}' is not a valid document.")
  322. _, doc = DocumentService.get_by_id(document_id) # get doc object
  323. doc_attributes = doc.to_dict()
  324. return construct_json_result(
  325. data={"progress": doc_attributes["progress"], "status": TaskStatus(doc_attributes["status"]).name},
  326. code=RetCode.SUCCESS
  327. )
  328. except Exception as e:
  329. return construct_error_response(e)
  330. @manager.route('/run', methods=['POST'])
  331. @token_required
  332. def run(tenant_id):
  333. req = request.json
  334. try:
  335. for id in req["doc_ids"]:
  336. info = {"run": str(req["run"]), "progress": 0}
  337. if str(req["run"]) == TaskStatus.RUNNING.value:
  338. info["progress_msg"] = ""
  339. info["chunk_num"] = 0
  340. info["token_num"] = 0
  341. DocumentService.update_by_id(id, info)
  342. # if str(req["run"]) == TaskStatus.CANCEL.value:
  343. tenant_id = DocumentService.get_tenant_id(id)
  344. if not tenant_id:
  345. return get_data_error_result(retmsg="Tenant not found!")
  346. ELASTICSEARCH.deleteByQuery(
  347. Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
  348. if str(req["run"]) == TaskStatus.RUNNING.value:
  349. TaskService.filter_delete([Task.doc_id == id])
  350. e, doc = DocumentService.get_by_id(id)
  351. doc = doc.to_dict()
  352. doc["tenant_id"] = tenant_id
  353. bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
  354. queue_tasks(doc, bucket, name)
  355. return get_json_result(data=True)
  356. except Exception as e:
  357. return server_error_response(e)
  358. @manager.route('/chunk/list', methods=['POST'])
  359. @token_required
  360. @validate_request("doc_id")
  361. def list_chunk(tenant_id):
  362. req = request.json
  363. doc_id = req["doc_id"]
  364. page = int(req.get("page", 1))
  365. size = int(req.get("size", 30))
  366. question = req.get("keywords", "")
  367. try:
  368. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  369. if not tenant_id:
  370. return get_data_error_result(retmsg="Tenant not found!")
  371. e, doc = DocumentService.get_by_id(doc_id)
  372. if not e:
  373. return get_data_error_result(retmsg="Document not found!")
  374. query = {
  375. "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
  376. }
  377. if "available_int" in req:
  378. query["available_int"] = int(req["available_int"])
  379. sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
  380. res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
  381. for id in sres.ids:
  382. d = {
  383. "chunk_id": id,
  384. "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
  385. id].get(
  386. "content_with_weight", ""),
  387. "doc_id": sres.field[id]["doc_id"],
  388. "docnm_kwd": sres.field[id]["docnm_kwd"],
  389. "important_kwd": sres.field[id].get("important_kwd", []),
  390. "img_id": sres.field[id].get("img_id", ""),
  391. "available_int": sres.field[id].get("available_int", 1),
  392. "positions": sres.field[id].get("position_int", "").split("\t")
  393. }
  394. if len(d["positions"]) % 5 == 0:
  395. poss = []
  396. for i in range(0, len(d["positions"]), 5):
  397. poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
  398. float(d["positions"][i + 3]), float(d["positions"][i + 4])])
  399. d["positions"] = poss
  400. res["chunks"].append(d)
  401. return get_json_result(data=res)
  402. except Exception as e:
  403. if str(e).find("not_found") > 0:
  404. return get_json_result(data=False, retmsg=f'No chunk found!',
  405. retcode=RetCode.DATA_ERROR)
  406. return server_error_response(e)
  407. @manager.route('/chunk/create', methods=['POST'])
  408. @token_required
  409. @validate_request("doc_id", "content_with_weight")
  410. def create(tenant_id):
  411. req = request.json
  412. md5 = hashlib.md5()
  413. md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
  414. chunck_id = md5.hexdigest()
  415. d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
  416. "content_with_weight": req["content_with_weight"]}
  417. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  418. d["important_kwd"] = req.get("important_kwd", [])
  419. d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
  420. d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
  421. d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
  422. try:
  423. e, doc = DocumentService.get_by_id(req["doc_id"])
  424. if not e:
  425. return get_data_error_result(retmsg="Document not found!")
  426. d["kb_id"] = [doc.kb_id]
  427. d["docnm_kwd"] = doc.name
  428. d["doc_id"] = doc.id
  429. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  430. if not tenant_id:
  431. return get_data_error_result(retmsg="Tenant not found!")
  432. embd_id = DocumentService.get_embd_id(req["doc_id"])
  433. embd_mdl = TenantLLMService.model_instance(
  434. tenant_id, LLMType.EMBEDDING.value, embd_id)
  435. v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
  436. v = 0.1 * v[0] + 0.9 * v[1]
  437. d["q_%d_vec" % len(v)] = v.tolist()
  438. ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
  439. DocumentService.increment_chunk_num(
  440. doc.id, doc.kb_id, c, 1, 0)
  441. return get_json_result(data={"chunk": d})
  442. # return get_json_result(data={"chunk_id": chunck_id})
  443. except Exception as e:
  444. return server_error_response(e)
  445. @manager.route('/chunk/rm', methods=['POST'])
  446. @token_required
  447. @validate_request("chunk_ids", "doc_id")
  448. def rm_chunk():
  449. req = request.json
  450. try:
  451. if not ELASTICSEARCH.deleteByQuery(
  452. Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
  453. return get_data_error_result(retmsg="Index updating failure")
  454. e, doc = DocumentService.get_by_id(req["doc_id"])
  455. if not e:
  456. return get_data_error_result(retmsg="Document not found!")
  457. deleted_chunk_ids = req["chunk_ids"]
  458. chunk_number = len(deleted_chunk_ids)
  459. DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
  460. return get_json_result(data=True)
  461. except Exception as e:
  462. return server_error_response(e)