Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

task_executor.py 18KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import sys
  18. from api.utils.log_utils import initRootLogger
  19. CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
  20. initRootLogger(f"task_executor_{CONSUMER_NO}")
  21. for module in ["pdfminer"]:
  22. module_logger = logging.getLogger(module)
  23. module_logger.setLevel(logging.WARNING)
  24. for module in ["peewee"]:
  25. module_logger = logging.getLogger(module)
  26. module_logger.handlers.clear()
  27. module_logger.propagate = True
  28. from datetime import datetime
  29. import json
  30. import os
  31. import hashlib
  32. import copy
  33. import re
  34. import sys
  35. import time
  36. import threading
  37. from functools import partial
  38. from io import BytesIO
  39. from multiprocessing.context import TimeoutError
  40. from timeit import default_timer as timer
  41. import numpy as np
  42. from api.db import LLMType, ParserType
  43. from api.db.services.dialog_service import keyword_extraction, question_proposal
  44. from api.db.services.document_service import DocumentService
  45. from api.db.services.llm_service import LLMBundle
  46. from api.db.services.task_service import TaskService
  47. from api.db.services.file2document_service import File2DocumentService
  48. from api import settings
  49. from api.db.db_models import close_connection
  50. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  51. knowledge_graph, email
  52. from rag.nlp import search, rag_tokenizer
  53. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  54. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
  55. from rag.utils import rmSpace, num_tokens_from_string
  56. from rag.utils.redis_conn import REDIS_CONN, Payload
  57. from rag.utils.storage_factory import STORAGE_IMPL
  58. BATCH_SIZE = 64
  59. FACTORY = {
  60. "general": naive,
  61. ParserType.NAIVE.value: naive,
  62. ParserType.PAPER.value: paper,
  63. ParserType.BOOK.value: book,
  64. ParserType.PRESENTATION.value: presentation,
  65. ParserType.MANUAL.value: manual,
  66. ParserType.LAWS.value: laws,
  67. ParserType.QA.value: qa,
  68. ParserType.TABLE.value: table,
  69. ParserType.RESUME.value: resume,
  70. ParserType.PICTURE.value: picture,
  71. ParserType.ONE.value: one,
  72. ParserType.AUDIO.value: audio,
  73. ParserType.EMAIL.value: email,
  74. ParserType.KG.value: knowledge_graph
  75. }
  76. CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
  77. PAYLOAD: Payload | None = None
  78. BOOT_AT = datetime.now().isoformat()
  79. PENDING_TASKS = 0
  80. LAG_TASKS = 0
  81. mt_lock = threading.Lock()
  82. DONE_TASKS = 0
  83. FAILED_TASKS = 0
  84. CURRENT_TASK = None
  85. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  86. global PAYLOAD
  87. if prog is not None and prog < 0:
  88. msg = "[ERROR]" + msg
  89. cancel = TaskService.do_cancel(task_id)
  90. if cancel:
  91. msg += " [Canceled]"
  92. prog = -1
  93. if to_page > 0:
  94. if msg:
  95. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  96. d = {"progress_msg": msg}
  97. if prog is not None:
  98. d["progress"] = prog
  99. try:
  100. TaskService.update_progress(task_id, d)
  101. except Exception:
  102. logging.exception(f"set_progress({task_id}) got exception")
  103. close_connection()
  104. if cancel:
  105. if PAYLOAD:
  106. PAYLOAD.ack()
  107. PAYLOAD = None
  108. os._exit(0)
  109. def collect():
  110. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  111. try:
  112. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  113. if not PAYLOAD:
  114. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  115. if not PAYLOAD:
  116. time.sleep(1)
  117. return None
  118. except Exception:
  119. logging.exception("Get task event from queue exception")
  120. return None
  121. msg = PAYLOAD.get_message()
  122. if not msg:
  123. return None
  124. if TaskService.do_cancel(msg["id"]):
  125. with mt_lock:
  126. DONE_TASKS += 1
  127. logging.info("Task {} has been canceled.".format(msg["id"]))
  128. return None
  129. task = TaskService.get_task(msg["id"])
  130. if not task:
  131. with mt_lock:
  132. DONE_TASKS += 1
  133. logging.warning("{} empty task!".format(msg["id"]))
  134. return None
  135. if msg.get("type", "") == "raptor":
  136. task["task_type"] = "raptor"
  137. return task
  138. def get_storage_binary(bucket, name):
  139. return STORAGE_IMPL.get(bucket, name)
  140. def build(row):
  141. if row["size"] > DOC_MAXIMUM_SIZE:
  142. set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  143. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  144. return []
  145. callback = partial(
  146. set_progress,
  147. row["id"],
  148. row["from_page"],
  149. row["to_page"])
  150. chunker = FACTORY[row["parser_id"].lower()]
  151. try:
  152. st = timer()
  153. bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
  154. binary = get_storage_binary(bucket, name)
  155. logging.info(
  156. "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
  157. except TimeoutError:
  158. callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  159. logging.exception(
  160. "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
  161. raise
  162. except Exception as e:
  163. if re.search("(No such file|not found)", str(e)):
  164. callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
  165. else:
  166. callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  167. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  168. raise
  169. try:
  170. cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
  171. to_page=row["to_page"], lang=row["language"], callback=callback,
  172. kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
  173. logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
  174. except Exception as e:
  175. callback(-1, "Internal server error while chunking: %s" %
  176. str(e).replace("'", ""))
  177. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  178. raise
  179. docs = []
  180. doc = {
  181. "doc_id": row["doc_id"],
  182. "kb_id": str(row["kb_id"])
  183. }
  184. el = 0
  185. for ck in cks:
  186. d = copy.deepcopy(doc)
  187. d.update(ck)
  188. md5 = hashlib.md5()
  189. md5.update((ck["content_with_weight"] +
  190. str(d["doc_id"])).encode("utf-8"))
  191. d["id"] = md5.hexdigest()
  192. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  193. d["create_timestamp_flt"] = datetime.now().timestamp()
  194. if not d.get("image"):
  195. _ = d.pop("image", None)
  196. d["img_id"] = ""
  197. d["page_num_list"] = json.dumps([])
  198. d["position_list"] = json.dumps([])
  199. d["top_list"] = json.dumps([])
  200. docs.append(d)
  201. continue
  202. try:
  203. output_buffer = BytesIO()
  204. if isinstance(d["image"], bytes):
  205. output_buffer = BytesIO(d["image"])
  206. else:
  207. d["image"].save(output_buffer, format='JPEG')
  208. st = timer()
  209. STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
  210. el += timer() - st
  211. except Exception:
  212. logging.exception(
  213. "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
  214. raise
  215. d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
  216. del d["image"]
  217. docs.append(d)
  218. logging.info("MINIO PUT({}):{}".format(row["name"], el))
  219. if row["parser_config"].get("auto_keywords", 0):
  220. st = timer()
  221. callback(msg="Start to generate keywords for every chunk ...")
  222. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  223. for d in docs:
  224. d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
  225. row["parser_config"]["auto_keywords"]).split(",")
  226. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  227. callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
  228. if row["parser_config"].get("auto_questions", 0):
  229. st = timer()
  230. callback(msg="Start to generate questions for every chunk ...")
  231. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  232. for d in docs:
  233. qst = question_proposal(chat_mdl, d["content_with_weight"], row["parser_config"]["auto_questions"])
  234. d["content_with_weight"] = f"Question: \n{qst}\n\nAnswer:\n" + d["content_with_weight"]
  235. qst = rag_tokenizer.tokenize(qst)
  236. if "content_ltks" in d:
  237. d["content_ltks"] += " " + qst
  238. if "content_sm_ltks" in d:
  239. d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
  240. callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
  241. return docs
  242. def init_kb(row, vector_size: int):
  243. idxnm = search.index_name(row["tenant_id"])
  244. return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
  245. def embedding(docs, mdl, parser_config=None, callback=None):
  246. if parser_config is None:
  247. parser_config = {}
  248. batch_size = 32
  249. tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
  250. re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
  251. tk_count = 0
  252. if len(tts) == len(cnts):
  253. tts_ = np.array([])
  254. for i in range(0, len(tts), batch_size):
  255. vts, c = mdl.encode(tts[i: i + batch_size])
  256. if len(tts_) == 0:
  257. tts_ = vts
  258. else:
  259. tts_ = np.concatenate((tts_, vts), axis=0)
  260. tk_count += c
  261. callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
  262. tts = tts_
  263. cnts_ = np.array([])
  264. for i in range(0, len(cnts), batch_size):
  265. vts, c = mdl.encode(cnts[i: i + batch_size])
  266. if len(cnts_) == 0:
  267. cnts_ = vts
  268. else:
  269. cnts_ = np.concatenate((cnts_, vts), axis=0)
  270. tk_count += c
  271. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  272. cnts = cnts_
  273. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  274. vects = (title_w * tts + (1 - title_w) *
  275. cnts) if len(tts) == len(cnts) else cnts
  276. assert len(vects) == len(docs)
  277. vector_size = 0
  278. for i, d in enumerate(docs):
  279. v = vects[i].tolist()
  280. vector_size = len(v)
  281. d["q_%d_vec" % len(v)] = v
  282. return tk_count, vector_size
  283. def run_raptor(row, chat_mdl, embd_mdl, callback=None):
  284. vts, _ = embd_mdl.encode(["ok"])
  285. vector_size = len(vts[0])
  286. vctr_nm = "q_%d_vec" % vector_size
  287. chunks = []
  288. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  289. fields=["content_with_weight", vctr_nm]):
  290. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  291. raptor = Raptor(
  292. row["parser_config"]["raptor"].get("max_cluster", 64),
  293. chat_mdl,
  294. embd_mdl,
  295. row["parser_config"]["raptor"]["prompt"],
  296. row["parser_config"]["raptor"]["max_token"],
  297. row["parser_config"]["raptor"]["threshold"]
  298. )
  299. original_length = len(chunks)
  300. raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  301. doc = {
  302. "doc_id": row["doc_id"],
  303. "kb_id": [str(row["kb_id"])],
  304. "docnm_kwd": row["name"],
  305. "title_tks": rag_tokenizer.tokenize(row["name"])
  306. }
  307. res = []
  308. tk_count = 0
  309. for content, vctr in chunks[original_length:]:
  310. d = copy.deepcopy(doc)
  311. md5 = hashlib.md5()
  312. md5.update((content + str(d["doc_id"])).encode("utf-8"))
  313. d["id"] = md5.hexdigest()
  314. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  315. d["create_timestamp_flt"] = datetime.now().timestamp()
  316. d[vctr_nm] = vctr.tolist()
  317. d["content_with_weight"] = content
  318. d["content_ltks"] = rag_tokenizer.tokenize(content)
  319. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  320. res.append(d)
  321. tk_count += num_tokens_from_string(content)
  322. return res, tk_count, vector_size
  323. def do_handle_task(r):
  324. callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
  325. try:
  326. embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
  327. except Exception as e:
  328. callback(-1, msg=str(e))
  329. raise
  330. if r.get("task_type", "") == "raptor":
  331. try:
  332. chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
  333. cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
  334. except Exception as e:
  335. callback(-1, msg=str(e))
  336. raise
  337. else:
  338. st = timer()
  339. cks = build(r)
  340. logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
  341. if cks is None:
  342. return
  343. if not cks:
  344. callback(1., "No chunk! Done!")
  345. return
  346. # TODO: exception handler
  347. ## set_progress(r["did"], -1, "ERROR: ")
  348. callback(
  349. msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
  350. timer() - st)
  351. )
  352. st = timer()
  353. try:
  354. tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
  355. except Exception as e:
  356. callback(-1, "Embedding error:{}".format(str(e)))
  357. logging.exception("run_rembedding got exception")
  358. tk_count = 0
  359. raise
  360. logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
  361. callback(msg="Finished embedding (in {:.2f}s)! Start to build index!".format(timer() - st))
  362. # logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
  363. init_kb(r, vector_size)
  364. chunk_count = len(set([c["id"] for c in cks]))
  365. st = timer()
  366. es_r = ""
  367. es_bulk_size = 4
  368. for b in range(0, len(cks), es_bulk_size):
  369. es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
  370. if b % 128 == 0:
  371. callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
  372. logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
  373. if es_r:
  374. callback(-1, "Insert chunk error, detail info please check log file. Please also check Elasticsearch/Infinity status!")
  375. settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  376. logging.error('Insert chunk error: ' + str(es_r))
  377. raise Exception('Insert chunk error: ' + str(es_r))
  378. if TaskService.do_cancel(r["id"]):
  379. settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  380. return
  381. callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
  382. callback(1., "Done!")
  383. DocumentService.increment_chunk_num(
  384. r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
  385. logging.info(
  386. "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
  387. r["id"], tk_count, len(cks), timer() - st))
  388. def handle_task():
  389. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  390. task = collect()
  391. if task:
  392. try:
  393. logging.info(f"handle_task begin for task {json.dumps(task)}")
  394. with mt_lock:
  395. CURRENT_TASK = copy.deepcopy(task)
  396. do_handle_task(task)
  397. with mt_lock:
  398. DONE_TASKS += 1
  399. CURRENT_TASK = None
  400. logging.info(f"handle_task done for task {json.dumps(task)}")
  401. except Exception:
  402. with mt_lock:
  403. FAILED_TASKS += 1
  404. CURRENT_TASK = None
  405. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  406. if PAYLOAD:
  407. PAYLOAD.ack()
  408. PAYLOAD = None
  409. def report_status():
  410. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  411. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  412. while True:
  413. try:
  414. now = datetime.now()
  415. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  416. if group_info is not None:
  417. PENDING_TASKS = int(group_info["pending"])
  418. LAG_TASKS = int(group_info["lag"])
  419. with mt_lock:
  420. heartbeat = json.dumps({
  421. "name": CONSUMER_NAME,
  422. "now": now.isoformat(),
  423. "boot_at": BOOT_AT,
  424. "pending": PENDING_TASKS,
  425. "lag": LAG_TASKS,
  426. "done": DONE_TASKS,
  427. "failed": FAILED_TASKS,
  428. "current": CURRENT_TASK,
  429. })
  430. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  431. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  432. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  433. if expired > 0:
  434. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  435. except Exception:
  436. logging.exception("report_status got exception")
  437. time.sleep(30)
  438. def main():
  439. settings.init_settings()
  440. background_thread = threading.Thread(target=report_status)
  441. background_thread.daemon = True
  442. background_thread.start()
  443. while True:
  444. handle_task()
  445. if __name__ == "__main__":
  446. main()