Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

task_executor.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import inspect
  18. from api.utils.log_utils import initRootLogger
  19. initRootLogger(inspect.getfile(inspect.currentframe()))
  20. for module in ["pdfminer"]:
  21. module_logger = logging.getLogger(module)
  22. module_logger.setLevel(logging.WARNING)
  23. for module in ["peewee"]:
  24. module_logger = logging.getLogger(module)
  25. module_logger.handlers.clear()
  26. module_logger.propagate = True
  27. import datetime
  28. import json
  29. import os
  30. import hashlib
  31. import copy
  32. import re
  33. import sys
  34. import time
  35. from concurrent.futures import ThreadPoolExecutor
  36. from functools import partial
  37. from io import BytesIO
  38. from multiprocessing.context import TimeoutError
  39. from timeit import default_timer as timer
  40. import numpy as np
  41. import pandas as pd
  42. from api.db import LLMType, ParserType
  43. from api.db.services.dialog_service import keyword_extraction, question_proposal
  44. from api.db.services.document_service import DocumentService
  45. from api.db.services.llm_service import LLMBundle
  46. from api.db.services.task_service import TaskService
  47. from api.db.services.file2document_service import File2DocumentService
  48. from api.settings import retrievaler, docStoreConn
  49. from api.db.db_models import close_connection
  50. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
  51. from rag.nlp import search, rag_tokenizer
  52. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  53. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
  54. from rag.utils import rmSpace, num_tokens_from_string
  55. from rag.utils.redis_conn import REDIS_CONN, Payload
  56. from rag.utils.storage_factory import STORAGE_IMPL
  57. BATCH_SIZE = 64
  58. FACTORY = {
  59. "general": naive,
  60. ParserType.NAIVE.value: naive,
  61. ParserType.PAPER.value: paper,
  62. ParserType.BOOK.value: book,
  63. ParserType.PRESENTATION.value: presentation,
  64. ParserType.MANUAL.value: manual,
  65. ParserType.LAWS.value: laws,
  66. ParserType.QA.value: qa,
  67. ParserType.TABLE.value: table,
  68. ParserType.RESUME.value: resume,
  69. ParserType.PICTURE.value: picture,
  70. ParserType.ONE.value: one,
  71. ParserType.AUDIO.value: audio,
  72. ParserType.EMAIL.value: email,
  73. ParserType.KG.value: knowledge_graph
  74. }
  75. CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
  76. PAYLOAD: Payload | None = None
  77. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  78. global PAYLOAD
  79. if prog is not None and prog < 0:
  80. msg = "[ERROR]" + msg
  81. cancel = TaskService.do_cancel(task_id)
  82. if cancel:
  83. msg += " [Canceled]"
  84. prog = -1
  85. if to_page > 0:
  86. if msg:
  87. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  88. d = {"progress_msg": msg}
  89. if prog is not None:
  90. d["progress"] = prog
  91. try:
  92. TaskService.update_progress(task_id, d)
  93. except Exception:
  94. logging.exception(f"set_progress({task_id}) got exception")
  95. close_connection()
  96. if cancel:
  97. if PAYLOAD:
  98. PAYLOAD.ack()
  99. PAYLOAD = None
  100. os._exit(0)
  101. def collect():
  102. global CONSUMER_NAME, PAYLOAD
  103. try:
  104. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  105. if not PAYLOAD:
  106. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  107. if not PAYLOAD:
  108. time.sleep(1)
  109. return pd.DataFrame()
  110. except Exception:
  111. logging.exception("Get task event from queue exception")
  112. return pd.DataFrame()
  113. msg = PAYLOAD.get_message()
  114. if not msg:
  115. return pd.DataFrame()
  116. if TaskService.do_cancel(msg["id"]):
  117. logging.info("Task {} has been canceled.".format(msg["id"]))
  118. return pd.DataFrame()
  119. tasks = TaskService.get_tasks(msg["id"])
  120. if not tasks:
  121. logging.warning("{} empty task!".format(msg["id"]))
  122. return []
  123. tasks = pd.DataFrame(tasks)
  124. if msg.get("type", "") == "raptor":
  125. tasks["task_type"] = "raptor"
  126. return tasks
  127. def get_storage_binary(bucket, name):
  128. return STORAGE_IMPL.get(bucket, name)
  129. def build(row):
  130. if row["size"] > DOC_MAXIMUM_SIZE:
  131. set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  132. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  133. return []
  134. callback = partial(
  135. set_progress,
  136. row["id"],
  137. row["from_page"],
  138. row["to_page"])
  139. chunker = FACTORY[row["parser_id"].lower()]
  140. try:
  141. st = timer()
  142. bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
  143. binary = get_storage_binary(bucket, name)
  144. logging.info(
  145. "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
  146. except TimeoutError:
  147. callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  148. logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
  149. return
  150. except Exception as e:
  151. if re.search("(No such file|not found)", str(e)):
  152. callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
  153. else:
  154. callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  155. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  156. return
  157. try:
  158. cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
  159. to_page=row["to_page"], lang=row["language"], callback=callback,
  160. kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
  161. logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
  162. except Exception as e:
  163. callback(-1, "Internal server error while chunking: %s" %
  164. str(e).replace("'", ""))
  165. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  166. return
  167. docs = []
  168. doc = {
  169. "doc_id": row["doc_id"],
  170. "kb_id": str(row["kb_id"])
  171. }
  172. el = 0
  173. for ck in cks:
  174. d = copy.deepcopy(doc)
  175. d.update(ck)
  176. md5 = hashlib.md5()
  177. md5.update((ck["content_with_weight"] +
  178. str(d["doc_id"])).encode("utf-8"))
  179. d["id"] = md5.hexdigest()
  180. d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
  181. d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
  182. if not d.get("image"):
  183. d["img_id"] = ""
  184. d["page_num_list"] = json.dumps([])
  185. d["position_list"] = json.dumps([])
  186. d["top_list"] = json.dumps([])
  187. docs.append(d)
  188. continue
  189. try:
  190. output_buffer = BytesIO()
  191. if isinstance(d["image"], bytes):
  192. output_buffer = BytesIO(d["image"])
  193. else:
  194. d["image"].save(output_buffer, format='JPEG')
  195. st = timer()
  196. STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
  197. el += timer() - st
  198. except Exception:
  199. logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
  200. d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
  201. del d["image"]
  202. docs.append(d)
  203. logging.info("MINIO PUT({}):{}".format(row["name"], el))
  204. if row["parser_config"].get("auto_keywords", 0):
  205. st = timer()
  206. callback(msg="Start to generate keywords for every chunk ...")
  207. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  208. for d in docs:
  209. d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
  210. row["parser_config"]["auto_keywords"]).split(",")
  211. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  212. callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
  213. if row["parser_config"].get("auto_questions", 0):
  214. st = timer()
  215. callback(msg="Start to generate questions for every chunk ...")
  216. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  217. for d in docs:
  218. qst = question_proposal(chat_mdl, d["content_with_weight"], row["parser_config"]["auto_questions"])
  219. d["content_with_weight"] = f"Question: \n{qst}\n\nAnswer:\n" + d["content_with_weight"]
  220. qst = rag_tokenizer.tokenize(qst)
  221. if "content_ltks" in d:
  222. d["content_ltks"] += " " + qst
  223. if "content_sm_ltks" in d:
  224. d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
  225. callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
  226. return docs
  227. def init_kb(row, vector_size: int):
  228. idxnm = search.index_name(row["tenant_id"])
  229. return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
  230. def embedding(docs, mdl, parser_config=None, callback=None):
  231. if parser_config is None:
  232. parser_config = {}
  233. batch_size = 32
  234. tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
  235. re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
  236. tk_count = 0
  237. if len(tts) == len(cnts):
  238. tts_ = np.array([])
  239. for i in range(0, len(tts), batch_size):
  240. vts, c = mdl.encode(tts[i: i + batch_size])
  241. if len(tts_) == 0:
  242. tts_ = vts
  243. else:
  244. tts_ = np.concatenate((tts_, vts), axis=0)
  245. tk_count += c
  246. callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
  247. tts = tts_
  248. cnts_ = np.array([])
  249. for i in range(0, len(cnts), batch_size):
  250. vts, c = mdl.encode(cnts[i: i + batch_size])
  251. if len(cnts_) == 0:
  252. cnts_ = vts
  253. else:
  254. cnts_ = np.concatenate((cnts_, vts), axis=0)
  255. tk_count += c
  256. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  257. cnts = cnts_
  258. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  259. vects = (title_w * tts + (1 - title_w) *
  260. cnts) if len(tts) == len(cnts) else cnts
  261. assert len(vects) == len(docs)
  262. vector_size = 0
  263. for i, d in enumerate(docs):
  264. v = vects[i].tolist()
  265. vector_size = len(v)
  266. d["q_%d_vec" % len(v)] = v
  267. return tk_count, vector_size
  268. def run_raptor(row, chat_mdl, embd_mdl, callback=None):
  269. vts, _ = embd_mdl.encode(["ok"])
  270. vector_size = len(vts[0])
  271. vctr_nm = "q_%d_vec" % vector_size
  272. chunks = []
  273. for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
  274. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  275. raptor = Raptor(
  276. row["parser_config"]["raptor"].get("max_cluster", 64),
  277. chat_mdl,
  278. embd_mdl,
  279. row["parser_config"]["raptor"]["prompt"],
  280. row["parser_config"]["raptor"]["max_token"],
  281. row["parser_config"]["raptor"]["threshold"]
  282. )
  283. original_length = len(chunks)
  284. raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  285. doc = {
  286. "doc_id": row["doc_id"],
  287. "kb_id": [str(row["kb_id"])],
  288. "docnm_kwd": row["name"],
  289. "title_tks": rag_tokenizer.tokenize(row["name"])
  290. }
  291. res = []
  292. tk_count = 0
  293. for content, vctr in chunks[original_length:]:
  294. d = copy.deepcopy(doc)
  295. md5 = hashlib.md5()
  296. md5.update((content + str(d["doc_id"])).encode("utf-8"))
  297. d["id"] = md5.hexdigest()
  298. d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
  299. d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
  300. d[vctr_nm] = vctr.tolist()
  301. d["content_with_weight"] = content
  302. d["content_ltks"] = rag_tokenizer.tokenize(content)
  303. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  304. res.append(d)
  305. tk_count += num_tokens_from_string(content)
  306. return res, tk_count, vector_size
  307. def main():
  308. rows = collect()
  309. if len(rows) == 0:
  310. return
  311. for _, r in rows.iterrows():
  312. callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
  313. try:
  314. embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
  315. except Exception as e:
  316. callback(-1, msg=str(e))
  317. logging.exception("LLMBundle got exception")
  318. continue
  319. if r.get("task_type", "") == "raptor":
  320. try:
  321. chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
  322. cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
  323. except Exception as e:
  324. callback(-1, msg=str(e))
  325. logging.exception("run_raptor got exception")
  326. continue
  327. else:
  328. st = timer()
  329. cks = build(r)
  330. logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
  331. if cks is None:
  332. continue
  333. if not cks:
  334. callback(1., "No chunk! Done!")
  335. continue
  336. # TODO: exception handler
  337. ## set_progress(r["did"], -1, "ERROR: ")
  338. callback(
  339. msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
  340. )
  341. st = timer()
  342. try:
  343. tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
  344. except Exception as e:
  345. callback(-1, "Embedding error:{}".format(str(e)))
  346. logging.exception("run_rembedding got exception")
  347. tk_count = 0
  348. logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
  349. callback(msg="Finished embedding (in {:.2f}s)! Start to build index!".format(timer() - st))
  350. # logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
  351. init_kb(r, vector_size)
  352. chunk_count = len(set([c["id"] for c in cks]))
  353. st = timer()
  354. es_r = ""
  355. es_bulk_size = 4
  356. for b in range(0, len(cks), es_bulk_size):
  357. es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
  358. if b % 128 == 0:
  359. callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
  360. logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
  361. if es_r:
  362. callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!")
  363. docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  364. logging.error('Insert chunk error: ' + str(es_r))
  365. else:
  366. if TaskService.do_cancel(r["id"]):
  367. docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  368. continue
  369. callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
  370. callback(1., "Done!")
  371. DocumentService.increment_chunk_num(
  372. r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
  373. logging.info(
  374. "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
  375. r["id"], tk_count, len(cks), timer() - st))
  376. def report_status():
  377. global CONSUMER_NAME
  378. while True:
  379. try:
  380. obj = REDIS_CONN.get("TASKEXE")
  381. if not obj: obj = {}
  382. else: obj = json.loads(obj)
  383. if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
  384. obj[CONSUMER_NAME].append(timer())
  385. obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
  386. REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
  387. except Exception:
  388. logging.exception("report_status got exception")
  389. time.sleep(30)
  390. if __name__ == "__main__":
  391. exe = ThreadPoolExecutor(max_workers=1)
  392. exe.submit(report_status)
  393. while True:
  394. main()
  395. if PAYLOAD:
  396. PAYLOAD.ack()
  397. PAYLOAD = None