You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

document_app.py 22KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License
  15. #
  16. import datetime
  17. import hashlib
  18. import json
  19. import os
  20. import pathlib
  21. import re
  22. import traceback
  23. from concurrent.futures import ThreadPoolExecutor
  24. from copy import deepcopy
  25. from io import BytesIO
  26. import flask
  27. from elasticsearch_dsl import Q
  28. from flask import request
  29. from flask_login import login_required, current_user
  30. from api.db.db_models import Task, File
  31. from api.db.services.dialog_service import DialogService, ConversationService
  32. from api.db.services.file2document_service import File2DocumentService
  33. from api.db.services.file_service import FileService
  34. from api.db.services.llm_service import LLMBundle
  35. from api.db.services.task_service import TaskService, queue_tasks
  36. from api.db.services.user_service import TenantService
  37. from graphrag.mind_map_extractor import MindMapExtractor
  38. from rag.app import naive
  39. from rag.nlp import search
  40. from rag.utils.es_conn import ELASTICSEARCH
  41. from api.db.services import duplicate_name
  42. from api.db.services.knowledgebase_service import KnowledgebaseService
  43. from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
  44. from api.utils import get_uuid
  45. from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType
  46. from api.db.services.document_service import DocumentService
  47. from api.settings import RetCode, stat_logger
  48. from api.utils.api_utils import get_json_result
  49. from rag.utils.minio_conn import MINIO
  50. from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
  51. from api.utils.web_utils import html2pdf, is_valid_url
  52. @manager.route('/upload', methods=['POST'])
  53. @login_required
  54. @validate_request("kb_id")
  55. def upload():
  56. kb_id = request.form.get("kb_id")
  57. if not kb_id:
  58. return get_json_result(
  59. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  60. if 'file' not in request.files:
  61. return get_json_result(
  62. data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
  63. file_objs = request.files.getlist('file')
  64. for file_obj in file_objs:
  65. if file_obj.filename == '':
  66. return get_json_result(
  67. data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
  68. e, kb = KnowledgebaseService.get_by_id(kb_id)
  69. if not e:
  70. raise LookupError("Can't find this knowledgebase!")
  71. err, _ = FileService.upload_document(kb, file_objs)
  72. if err:
  73. return get_json_result(
  74. data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
  75. return get_json_result(data=True)
  76. @manager.route('/web_crawl', methods=['POST'])
  77. @login_required
  78. @validate_request("kb_id", "name", "url")
  79. def web_crawl():
  80. kb_id = request.form.get("kb_id")
  81. if not kb_id:
  82. return get_json_result(
  83. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  84. name = request.form.get("name")
  85. url = request.form.get("url")
  86. if not is_valid_url(url):
  87. return get_json_result(
  88. data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)
  89. e, kb = KnowledgebaseService.get_by_id(kb_id)
  90. if not e:
  91. raise LookupError("Can't find this knowledgebase!")
  92. blob = html2pdf(url)
  93. if not blob: return server_error_response(ValueError("Download failure."))
  94. root_folder = FileService.get_root_folder(current_user.id)
  95. pf_id = root_folder["id"]
  96. FileService.init_knowledgebase_docs(pf_id, current_user.id)
  97. kb_root_folder = FileService.get_kb_folder(current_user.id)
  98. kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
  99. try:
  100. filename = duplicate_name(
  101. DocumentService.query,
  102. name=name + ".pdf",
  103. kb_id=kb.id)
  104. filetype = filename_type(filename)
  105. if filetype == FileType.OTHER.value:
  106. raise RuntimeError("This type of file has not been supported yet!")
  107. location = filename
  108. while MINIO.obj_exist(kb_id, location):
  109. location += "_"
  110. MINIO.put(kb_id, location, blob)
  111. doc = {
  112. "id": get_uuid(),
  113. "kb_id": kb.id,
  114. "parser_id": kb.parser_id,
  115. "parser_config": kb.parser_config,
  116. "created_by": current_user.id,
  117. "type": filetype,
  118. "name": filename,
  119. "location": location,
  120. "size": len(blob),
  121. "thumbnail": thumbnail(filename, blob)
  122. }
  123. if doc["type"] == FileType.VISUAL:
  124. doc["parser_id"] = ParserType.PICTURE.value
  125. if doc["type"] == FileType.AURAL:
  126. doc["parser_id"] = ParserType.AUDIO.value
  127. if re.search(r"\.(ppt|pptx|pages)$", filename):
  128. doc["parser_id"] = ParserType.PRESENTATION.value
  129. DocumentService.insert(doc)
  130. FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
  131. except Exception as e:
  132. return server_error_response(e)
  133. return get_json_result(data=True)
  134. @manager.route('/create', methods=['POST'])
  135. @login_required
  136. @validate_request("name", "kb_id")
  137. def create():
  138. req = request.json
  139. kb_id = req["kb_id"]
  140. if not kb_id:
  141. return get_json_result(
  142. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  143. try:
  144. e, kb = KnowledgebaseService.get_by_id(kb_id)
  145. if not e:
  146. return get_data_error_result(
  147. retmsg="Can't find this knowledgebase!")
  148. if DocumentService.query(name=req["name"], kb_id=kb_id):
  149. return get_data_error_result(
  150. retmsg="Duplicated document name in the same knowledgebase.")
  151. doc = DocumentService.insert({
  152. "id": get_uuid(),
  153. "kb_id": kb.id,
  154. "parser_id": kb.parser_id,
  155. "parser_config": kb.parser_config,
  156. "created_by": current_user.id,
  157. "type": FileType.VIRTUAL,
  158. "name": req["name"],
  159. "location": "",
  160. "size": 0
  161. })
  162. return get_json_result(data=doc.to_json())
  163. except Exception as e:
  164. return server_error_response(e)
  165. @manager.route('/list', methods=['GET'])
  166. @login_required
  167. def list_docs():
  168. kb_id = request.args.get("kb_id")
  169. if not kb_id:
  170. return get_json_result(
  171. data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
  172. keywords = request.args.get("keywords", "")
  173. page_number = int(request.args.get("page", 1))
  174. items_per_page = int(request.args.get("page_size", 15))
  175. orderby = request.args.get("orderby", "create_time")
  176. desc = request.args.get("desc", True)
  177. try:
  178. docs, tol = DocumentService.get_by_kb_id(
  179. kb_id, page_number, items_per_page, orderby, desc, keywords)
  180. return get_json_result(data={"total": tol, "docs": docs})
  181. except Exception as e:
  182. return server_error_response(e)
  183. @manager.route('/thumbnails', methods=['GET'])
  184. @login_required
  185. def thumbnails():
  186. doc_ids = request.args.get("doc_ids").split(",")
  187. if not doc_ids:
  188. return get_json_result(
  189. data=False, retmsg='Lack of "Document ID"', retcode=RetCode.ARGUMENT_ERROR)
  190. try:
  191. docs = DocumentService.get_thumbnails(doc_ids)
  192. return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
  193. except Exception as e:
  194. return server_error_response(e)
  195. @manager.route('/change_status', methods=['POST'])
  196. @login_required
  197. @validate_request("doc_id", "status")
  198. def change_status():
  199. req = request.json
  200. if str(req["status"]) not in ["0", "1"]:
  201. get_json_result(
  202. data=False,
  203. retmsg='"Status" must be either 0 or 1!',
  204. retcode=RetCode.ARGUMENT_ERROR)
  205. try:
  206. e, doc = DocumentService.get_by_id(req["doc_id"])
  207. if not e:
  208. return get_data_error_result(retmsg="Document not found!")
  209. e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
  210. if not e:
  211. return get_data_error_result(
  212. retmsg="Can't find this knowledgebase!")
  213. if not DocumentService.update_by_id(
  214. req["doc_id"], {"status": str(req["status"])}):
  215. return get_data_error_result(
  216. retmsg="Database error (Document update)!")
  217. if str(req["status"]) == "0":
  218. ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
  219. scripts="ctx._source.available_int=0;",
  220. idxnm=search.index_name(
  221. kb.tenant_id)
  222. )
  223. else:
  224. ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
  225. scripts="ctx._source.available_int=1;",
  226. idxnm=search.index_name(
  227. kb.tenant_id)
  228. )
  229. return get_json_result(data=True)
  230. except Exception as e:
  231. return server_error_response(e)
  232. @manager.route('/rm', methods=['POST'])
  233. @login_required
  234. @validate_request("doc_id")
  235. def rm():
  236. req = request.json
  237. doc_ids = req["doc_id"]
  238. if isinstance(doc_ids, str): doc_ids = [doc_ids]
  239. root_folder = FileService.get_root_folder(current_user.id)
  240. pf_id = root_folder["id"]
  241. FileService.init_knowledgebase_docs(pf_id, current_user.id)
  242. errors = ""
  243. for doc_id in doc_ids:
  244. try:
  245. e, doc = DocumentService.get_by_id(doc_id)
  246. if not e:
  247. return get_data_error_result(retmsg="Document not found!")
  248. tenant_id = DocumentService.get_tenant_id(doc_id)
  249. if not tenant_id:
  250. return get_data_error_result(retmsg="Tenant not found!")
  251. b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
  252. if not DocumentService.remove_document(doc, tenant_id):
  253. return get_data_error_result(
  254. retmsg="Database error (Document removal)!")
  255. f2d = File2DocumentService.get_by_document_id(doc_id)
  256. FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
  257. File2DocumentService.delete_by_document_id(doc_id)
  258. MINIO.rm(b, n)
  259. except Exception as e:
  260. errors += str(e)
  261. if errors:
  262. return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
  263. return get_json_result(data=True)
  264. @manager.route('/run', methods=['POST'])
  265. @login_required
  266. @validate_request("doc_ids", "run")
  267. def run():
  268. req = request.json
  269. try:
  270. for id in req["doc_ids"]:
  271. info = {"run": str(req["run"]), "progress": 0}
  272. if str(req["run"]) == TaskStatus.RUNNING.value:
  273. info["progress_msg"] = ""
  274. info["chunk_num"] = 0
  275. info["token_num"] = 0
  276. DocumentService.update_by_id(id, info)
  277. # if str(req["run"]) == TaskStatus.CANCEL.value:
  278. tenant_id = DocumentService.get_tenant_id(id)
  279. if not tenant_id:
  280. return get_data_error_result(retmsg="Tenant not found!")
  281. ELASTICSEARCH.deleteByQuery(
  282. Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
  283. if str(req["run"]) == TaskStatus.RUNNING.value:
  284. TaskService.filter_delete([Task.doc_id == id])
  285. e, doc = DocumentService.get_by_id(id)
  286. doc = doc.to_dict()
  287. doc["tenant_id"] = tenant_id
  288. bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
  289. queue_tasks(doc, bucket, name)
  290. return get_json_result(data=True)
  291. except Exception as e:
  292. return server_error_response(e)
  293. @manager.route('/rename', methods=['POST'])
  294. @login_required
  295. @validate_request("doc_id", "name")
  296. def rename():
  297. req = request.json
  298. try:
  299. e, doc = DocumentService.get_by_id(req["doc_id"])
  300. if not e:
  301. return get_data_error_result(retmsg="Document not found!")
  302. if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
  303. doc.name.lower()).suffix:
  304. return get_json_result(
  305. data=False,
  306. retmsg="The extension of file can't be changed",
  307. retcode=RetCode.ARGUMENT_ERROR)
  308. for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
  309. if d.name == req["name"]:
  310. return get_data_error_result(
  311. retmsg="Duplicated document name in the same knowledgebase.")
  312. if not DocumentService.update_by_id(
  313. req["doc_id"], {"name": req["name"]}):
  314. return get_data_error_result(
  315. retmsg="Database error (Document rename)!")
  316. informs = File2DocumentService.get_by_document_id(req["doc_id"])
  317. if informs:
  318. e, file = FileService.get_by_id(informs[0].file_id)
  319. FileService.update_by_id(file.id, {"name": req["name"]})
  320. return get_json_result(data=True)
  321. except Exception as e:
  322. return server_error_response(e)
  323. @manager.route('/get/<doc_id>', methods=['GET'])
  324. # @login_required
  325. def get(doc_id):
  326. try:
  327. e, doc = DocumentService.get_by_id(doc_id)
  328. if not e:
  329. return get_data_error_result(retmsg="Document not found!")
  330. b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
  331. response = flask.make_response(MINIO.get(b, n))
  332. ext = re.search(r"\.([^.]+)$", doc.name)
  333. if ext:
  334. if doc.type == FileType.VISUAL.value:
  335. response.headers.set('Content-Type', 'image/%s' % ext.group(1))
  336. else:
  337. response.headers.set(
  338. 'Content-Type',
  339. 'application/%s' %
  340. ext.group(1))
  341. return response
  342. except Exception as e:
  343. return server_error_response(e)
  344. @manager.route('/change_parser', methods=['POST'])
  345. @login_required
  346. @validate_request("doc_id", "parser_id")
  347. def change_parser():
  348. req = request.json
  349. try:
  350. e, doc = DocumentService.get_by_id(req["doc_id"])
  351. if not e:
  352. return get_data_error_result(retmsg="Document not found!")
  353. if doc.parser_id.lower() == req["parser_id"].lower():
  354. if "parser_config" in req:
  355. if req["parser_config"] == doc.parser_config:
  356. return get_json_result(data=True)
  357. else:
  358. return get_json_result(data=True)
  359. if doc.type == FileType.VISUAL or re.search(
  360. r"\.(ppt|pptx|pages)$", doc.name):
  361. return get_data_error_result(retmsg="Not supported yet!")
  362. e = DocumentService.update_by_id(doc.id,
  363. {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
  364. "run": TaskStatus.UNSTART.value})
  365. if not e:
  366. return get_data_error_result(retmsg="Document not found!")
  367. if "parser_config" in req:
  368. DocumentService.update_parser_config(doc.id, req["parser_config"])
  369. if doc.token_num > 0:
  370. e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
  371. doc.process_duation * -1)
  372. if not e:
  373. return get_data_error_result(retmsg="Document not found!")
  374. tenant_id = DocumentService.get_tenant_id(req["doc_id"])
  375. if not tenant_id:
  376. return get_data_error_result(retmsg="Tenant not found!")
  377. ELASTICSEARCH.deleteByQuery(
  378. Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
  379. return get_json_result(data=True)
  380. except Exception as e:
  381. return server_error_response(e)
  382. @manager.route('/image/<image_id>', methods=['GET'])
  383. # @login_required
  384. def get_image(image_id):
  385. try:
  386. bkt, nm = image_id.split("-")
  387. response = flask.make_response(MINIO.get(bkt, nm))
  388. response.headers.set('Content-Type', 'image/JPEG')
  389. return response
  390. except Exception as e:
  391. return server_error_response(e)
  392. @manager.route('/upload_and_parse', methods=['POST'])
  393. @login_required
  394. @validate_request("conversation_id")
  395. def upload_and_parse():
  396. from rag.app import presentation, picture, naive, audio, email
  397. if 'file' not in request.files:
  398. return get_json_result(
  399. data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
  400. file_objs = request.files.getlist('file')
  401. for file_obj in file_objs:
  402. if file_obj.filename == '':
  403. return get_json_result(
  404. data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
  405. e, conv = ConversationService.get_by_id(request.form.get("conversation_id"))
  406. if not e:
  407. return get_data_error_result(retmsg="Conversation not found!")
  408. e, dia = DialogService.get_by_id(conv.dialog_id)
  409. kb_id = dia.kb_ids[0]
  410. e, kb = KnowledgebaseService.get_by_id(kb_id)
  411. if not e:
  412. raise LookupError("Can't find this knowledgebase!")
  413. idxnm = search.index_name(kb.tenant_id)
  414. if not ELASTICSEARCH.indexExist(idxnm):
  415. ELASTICSEARCH.createIdx(idxnm, json.load(
  416. open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
  417. embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
  418. err, files = FileService.upload_document(kb, file_objs)
  419. if err:
  420. return get_json_result(
  421. data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
  422. def dummy(prog=None, msg=""):
  423. pass
  424. FACTORY = {
  425. ParserType.PRESENTATION.value: presentation,
  426. ParserType.PICTURE.value: picture,
  427. ParserType.AUDIO.value: audio,
  428. ParserType.EMAIL.value: email
  429. }
  430. parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
  431. exe = ThreadPoolExecutor(max_workers=12)
  432. threads = []
  433. for d, blob in files:
  434. kwargs = {
  435. "callback": dummy,
  436. "parser_config": parser_config,
  437. "from_page": 0,
  438. "to_page": 100000,
  439. "tenant_id": kb.tenant_id,
  440. "lang": kb.language
  441. }
  442. threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
  443. for (docinfo,_), th in zip(files, threads):
  444. docs = []
  445. doc = {
  446. "doc_id": docinfo["id"],
  447. "kb_id": [kb.id]
  448. }
  449. for ck in th.result():
  450. d = deepcopy(doc)
  451. d.update(ck)
  452. md5 = hashlib.md5()
  453. md5.update((ck["content_with_weight"] +
  454. str(d["doc_id"])).encode("utf-8"))
  455. d["_id"] = md5.hexdigest()
  456. d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
  457. d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
  458. if not d.get("image"):
  459. docs.append(d)
  460. continue
  461. output_buffer = BytesIO()
  462. if isinstance(d["image"], bytes):
  463. output_buffer = BytesIO(d["image"])
  464. else:
  465. d["image"].save(output_buffer, format='JPEG')
  466. MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
  467. d["img_id"] = "{}-{}".format(kb.id, d["_id"])
  468. del d["image"]
  469. docs.append(d)
  470. parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
  471. docids = [d["id"] for d, _ in files]
  472. chunk_counts = {id: 0 for id in docids}
  473. token_counts = {id: 0 for id in docids}
  474. es_bulk_size = 64
  475. def embedding(doc_id, cnts, batch_size=16):
  476. nonlocal embd_mdl, chunk_counts, token_counts
  477. vects = []
  478. for i in range(0, len(cnts), batch_size):
  479. vts, c = embd_mdl.encode(cnts[i: i + batch_size])
  480. vects.extend(vts.tolist())
  481. chunk_counts[doc_id] += len(cnts[i:i + batch_size])
  482. token_counts[doc_id] += c
  483. return vects
  484. _, tenant = TenantService.get_by_id(kb.tenant_id)
  485. llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
  486. for doc_id in docids:
  487. cks = [c for c in docs if c["doc_id"] == doc_id]
  488. if False and parser_ids[doc_id] != ParserType.PICTURE.value:
  489. mindmap = MindMapExtractor(llm_bdl)
  490. try:
  491. mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2)
  492. if len(mind_map) < 32: raise Exception("Few content: "+mind_map)
  493. cks.append({
  494. "doc_id": doc_id,
  495. "kb_id": [kb.id],
  496. "content_with_weight": mind_map,
  497. "knowledge_graph_kwd": "mind_map"
  498. })
  499. except Exception as e:
  500. stat_logger.error("Mind map generation error:", traceback.format_exc())
  501. vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
  502. assert len(cks) == len(vects)
  503. for i, d in enumerate(cks):
  504. v = vects[i]
  505. d["q_%d_vec" % len(v)] = v
  506. for b in range(0, len(cks), es_bulk_size):
  507. ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
  508. DocumentService.increment_chunk_num(
  509. doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
  510. return get_json_result(data=[d["id"] for d,_ in files])