You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

task_executor.py 21KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import logging
  19. import sys
  20. from api.utils.log_utils import initRootLogger
# Consumer number taken from the first CLI argument (defaults to "0"). It is
# used for the logger name here and, further down in this file, for the
# queue-consumer name.
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
# Root logging is configured before the remaining imports of this file.
# NOTE(review): presumably so records emitted while those modules import go
# through the configured root logger -- confirm the ordering is intentional.
initRootLogger(CONSUMER_NAME)
# Raise pdfminer's logger threshold to WARNING.
for module in ["pdfminer"]:
    module_logger = logging.getLogger(module)
    module_logger.setLevel(logging.WARNING)
# Remove peewee's own handlers and let its records propagate to the root logger.
for module in ["peewee"]:
    module_logger = logging.getLogger(module)
    module_logger.handlers.clear()
    module_logger.propagate = True
  31. from datetime import datetime
  32. import json
  33. import os
  34. import hashlib
  35. import copy
  36. import re
  37. import sys
  38. import time
  39. import threading
  40. from functools import partial
  41. from io import BytesIO
  42. from multiprocessing.context import TimeoutError
  43. from timeit import default_timer as timer
  44. import tracemalloc
  45. import numpy as np
  46. from api.db import LLMType, ParserType
  47. from api.db.services.dialog_service import keyword_extraction, question_proposal
  48. from api.db.services.document_service import DocumentService
  49. from api.db.services.llm_service import LLMBundle
  50. from api.db.services.task_service import TaskService
  51. from api.db.services.file2document_service import File2DocumentService
  52. from api import settings
  53. from api.versions import get_ragflow_version
  54. from api.db.db_models import close_connection
  55. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  56. knowledge_graph, email
  57. from rag.nlp import search, rag_tokenizer
  58. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  59. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
  60. from rag.utils import rmSpace, num_tokens_from_string
  61. from rag.utils.redis_conn import REDIS_CONN, Payload
  62. from rag.utils.storage_factory import STORAGE_IMPL
# NOTE(review): BATCH_SIZE is not referenced by the code visible in this file
# (embedding() uses its own local batch_size of 32) -- confirm it is still used.
BATCH_SIZE = 64

# Maps a document's parser_id (lower-cased by build()) to the rag.app module
# whose chunk() function splits the document.
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: knowledge_graph
}

# Re-bound from the "task_executor_<N>" logger name above. From here on this
# value is the consumer name passed to the Redis queue calls in collect() and
# the heartbeat key written by report_status().
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
# Queue message currently being handled; set by collect(), acknowledged and
# cleared by handle_task() (or by set_progress() when a task is cancelled).
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().isoformat()
# Queue statistics refreshed by report_status().
PENDING_TASKS = 0
LAG_TASKS = 0

# Guards the counters and CURRENT_TASK below: the main loop updates them while
# the report_status() background thread reads them for the heartbeat.
mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None
  90. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  91. global PAYLOAD
  92. if prog is not None and prog < 0:
  93. msg = "[ERROR]" + msg
  94. cancel = TaskService.do_cancel(task_id)
  95. if cancel:
  96. msg += " [Canceled]"
  97. prog = -1
  98. if to_page > 0:
  99. if msg:
  100. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  101. d = {"progress_msg": msg}
  102. if prog is not None:
  103. d["progress"] = prog
  104. try:
  105. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  106. TaskService.update_progress(task_id, d)
  107. except Exception:
  108. logging.exception(f"set_progress({task_id}) got exception")
  109. close_connection()
  110. if cancel:
  111. if PAYLOAD:
  112. PAYLOAD.ack()
  113. PAYLOAD = None
  114. os._exit(0)
  115. def collect():
  116. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  117. try:
  118. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  119. if not PAYLOAD:
  120. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  121. if not PAYLOAD:
  122. time.sleep(1)
  123. return None
  124. except Exception:
  125. logging.exception("Get task event from queue exception")
  126. return None
  127. msg = PAYLOAD.get_message()
  128. if not msg:
  129. return None
  130. if TaskService.do_cancel(msg["id"]):
  131. with mt_lock:
  132. DONE_TASKS += 1
  133. logging.info("Task {} has been canceled.".format(msg["id"]))
  134. return None
  135. task = TaskService.get_task(msg["id"])
  136. if not task:
  137. with mt_lock:
  138. DONE_TASKS += 1
  139. logging.warning("{} empty task!".format(msg["id"]))
  140. return None
  141. if msg.get("type", "") == "raptor":
  142. task["task_type"] = "raptor"
  143. return task
  144. def get_storage_binary(bucket, name):
  145. return STORAGE_IMPL.get(bucket, name)
  146. def build(row):
  147. if row["size"] > DOC_MAXIMUM_SIZE:
  148. set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  149. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  150. return []
  151. callback = partial(
  152. set_progress,
  153. row["id"],
  154. row["from_page"],
  155. row["to_page"])
  156. chunker = FACTORY[row["parser_id"].lower()]
  157. try:
  158. st = timer()
  159. bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
  160. binary = get_storage_binary(bucket, name)
  161. logging.info(
  162. "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
  163. except TimeoutError:
  164. callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  165. logging.exception(
  166. "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
  167. raise
  168. except Exception as e:
  169. if re.search("(No such file|not found)", str(e)):
  170. callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
  171. else:
  172. callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  173. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  174. raise
  175. try:
  176. cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
  177. to_page=row["to_page"], lang=row["language"], callback=callback,
  178. kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
  179. logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
  180. except Exception as e:
  181. callback(-1, "Internal server error while chunking: %s" %
  182. str(e).replace("'", ""))
  183. logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
  184. raise
  185. docs = []
  186. doc = {
  187. "doc_id": row["doc_id"],
  188. "kb_id": str(row["kb_id"])
  189. }
  190. el = 0
  191. for ck in cks:
  192. d = copy.deepcopy(doc)
  193. d.update(ck)
  194. md5 = hashlib.md5()
  195. md5.update((ck["content_with_weight"] +
  196. str(d["doc_id"])).encode("utf-8"))
  197. d["id"] = md5.hexdigest()
  198. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  199. d["create_timestamp_flt"] = datetime.now().timestamp()
  200. if not d.get("image"):
  201. _ = d.pop("image", None)
  202. d["img_id"] = ""
  203. d["page_num_list"] = json.dumps([])
  204. d["position_list"] = json.dumps([])
  205. d["top_list"] = json.dumps([])
  206. docs.append(d)
  207. continue
  208. try:
  209. output_buffer = BytesIO()
  210. if isinstance(d["image"], bytes):
  211. output_buffer = BytesIO(d["image"])
  212. else:
  213. d["image"].save(output_buffer, format='JPEG')
  214. st = timer()
  215. STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
  216. el += timer() - st
  217. except Exception:
  218. logging.exception(
  219. "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
  220. raise
  221. d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
  222. del d["image"]
  223. docs.append(d)
  224. logging.info("MINIO PUT({}):{}".format(row["name"], el))
  225. if row["parser_config"].get("auto_keywords", 0):
  226. st = timer()
  227. callback(msg="Start to generate keywords for every chunk ...")
  228. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  229. for d in docs:
  230. d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
  231. row["parser_config"]["auto_keywords"]).split(",")
  232. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  233. callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
  234. if row["parser_config"].get("auto_questions", 0):
  235. st = timer()
  236. callback(msg="Start to generate questions for every chunk ...")
  237. chat_mdl = LLMBundle(row["tenant_id"], LLMType.CHAT, llm_name=row["llm_id"], lang=row["language"])
  238. for d in docs:
  239. qst = question_proposal(chat_mdl, d["content_with_weight"], row["parser_config"]["auto_questions"])
  240. d["content_with_weight"] = f"Question: \n{qst}\n\nAnswer:\n" + d["content_with_weight"]
  241. qst = rag_tokenizer.tokenize(qst)
  242. if "content_ltks" in d:
  243. d["content_ltks"] += " " + qst
  244. if "content_sm_ltks" in d:
  245. d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
  246. callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
  247. return docs
  248. def init_kb(row, vector_size: int):
  249. idxnm = search.index_name(row["tenant_id"])
  250. return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
  251. def embedding(docs, mdl, parser_config=None, callback=None):
  252. if parser_config is None:
  253. parser_config = {}
  254. batch_size = 32
  255. tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
  256. re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
  257. tk_count = 0
  258. if len(tts) == len(cnts):
  259. tts_ = np.array([])
  260. for i in range(0, len(tts), batch_size):
  261. vts, c = mdl.encode(tts[i: i + batch_size])
  262. if len(tts_) == 0:
  263. tts_ = vts
  264. else:
  265. tts_ = np.concatenate((tts_, vts), axis=0)
  266. tk_count += c
  267. callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
  268. tts = tts_
  269. cnts_ = np.array([])
  270. for i in range(0, len(cnts), batch_size):
  271. vts, c = mdl.encode(cnts[i: i + batch_size])
  272. if len(cnts_) == 0:
  273. cnts_ = vts
  274. else:
  275. cnts_ = np.concatenate((cnts_, vts), axis=0)
  276. tk_count += c
  277. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  278. cnts = cnts_
  279. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  280. vects = (title_w * tts + (1 - title_w) *
  281. cnts) if len(tts) == len(cnts) else cnts
  282. assert len(vects) == len(docs)
  283. vector_size = 0
  284. for i, d in enumerate(docs):
  285. v = vects[i].tolist()
  286. vector_size = len(v)
  287. d["q_%d_vec" % len(v)] = v
  288. return tk_count, vector_size
  289. def run_raptor(row, chat_mdl, embd_mdl, callback=None):
  290. vts, _ = embd_mdl.encode(["ok"])
  291. vector_size = len(vts[0])
  292. vctr_nm = "q_%d_vec" % vector_size
  293. chunks = []
  294. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  295. fields=["content_with_weight", vctr_nm]):
  296. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  297. raptor = Raptor(
  298. row["parser_config"]["raptor"].get("max_cluster", 64),
  299. chat_mdl,
  300. embd_mdl,
  301. row["parser_config"]["raptor"]["prompt"],
  302. row["parser_config"]["raptor"]["max_token"],
  303. row["parser_config"]["raptor"]["threshold"]
  304. )
  305. original_length = len(chunks)
  306. chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  307. doc = {
  308. "doc_id": row["doc_id"],
  309. "kb_id": [str(row["kb_id"])],
  310. "docnm_kwd": row["name"],
  311. "title_tks": rag_tokenizer.tokenize(row["name"])
  312. }
  313. res = []
  314. tk_count = 0
  315. for content, vctr in chunks[original_length:]:
  316. d = copy.deepcopy(doc)
  317. md5 = hashlib.md5()
  318. md5.update((content + str(d["doc_id"])).encode("utf-8"))
  319. d["id"] = md5.hexdigest()
  320. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  321. d["create_timestamp_flt"] = datetime.now().timestamp()
  322. d[vctr_nm] = vctr.tolist()
  323. d["content_with_weight"] = content
  324. d["content_ltks"] = rag_tokenizer.tokenize(content)
  325. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  326. res.append(d)
  327. tk_count += num_tokens_from_string(content)
  328. return res, tk_count, vector_size
  329. def do_handle_task(r):
  330. callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
  331. try:
  332. embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
  333. except Exception as e:
  334. callback(-1, msg=str(e))
  335. raise
  336. if r.get("task_type", "") == "raptor":
  337. try:
  338. chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
  339. cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
  340. except Exception as e:
  341. callback(-1, msg=str(e))
  342. raise
  343. else:
  344. st = timer()
  345. cks = build(r)
  346. logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
  347. if cks is None:
  348. return
  349. if not cks:
  350. callback(1., "No chunk! Done!")
  351. return
  352. # TODO: exception handler
  353. ## set_progress(r["did"], -1, "ERROR: ")
  354. callback(
  355. msg="Generate {} chunks ({:.2f}s). Embedding chunks.".format(len(cks), timer() - st)
  356. )
  357. st = timer()
  358. try:
  359. tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
  360. except Exception as e:
  361. callback(-1, "Embedding error:{}".format(str(e)))
  362. logging.exception("run_rembedding got exception")
  363. tk_count = 0
  364. raise
  365. logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
  366. callback(msg="Finished embedding ({:.2f}s)!".format(timer() - st))
  367. # logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
  368. init_kb(r, vector_size)
  369. chunk_count = len(set([c["id"] for c in cks]))
  370. st = timer()
  371. es_r = ""
  372. es_bulk_size = 4
  373. for b in range(0, len(cks), es_bulk_size):
  374. es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
  375. if b % 128 == 0:
  376. callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
  377. logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
  378. if es_r:
  379. callback(-1,
  380. "Insert chunk error, detail info please check log file. Please also check Elasticsearch/Infinity status!")
  381. settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  382. logging.error('Insert chunk error: ' + str(es_r))
  383. raise Exception('Insert chunk error: ' + str(es_r))
  384. if TaskService.do_cancel(r["id"]):
  385. settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
  386. return
  387. callback(1., msg="Index cost {:.2f}s.".format(timer() - st))
  388. DocumentService.increment_chunk_num(
  389. r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
  390. logging.info(
  391. "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
  392. r["id"], tk_count, len(cks), timer() - st))
  393. def handle_task():
  394. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  395. task = collect()
  396. if task:
  397. try:
  398. logging.info(f"handle_task begin for task {json.dumps(task)}")
  399. with mt_lock:
  400. CURRENT_TASK = copy.deepcopy(task)
  401. do_handle_task(task)
  402. with mt_lock:
  403. DONE_TASKS += 1
  404. CURRENT_TASK = None
  405. logging.info(f"handle_task done for task {json.dumps(task)}")
  406. except Exception:
  407. with mt_lock:
  408. FAILED_TASKS += 1
  409. CURRENT_TASK = None
  410. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  411. if PAYLOAD:
  412. PAYLOAD.ack()
  413. PAYLOAD = None
  414. def report_status():
  415. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  416. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  417. while True:
  418. try:
  419. now = datetime.now()
  420. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  421. if group_info is not None:
  422. PENDING_TASKS = int(group_info["pending"])
  423. LAG_TASKS = int(group_info["lag"])
  424. with mt_lock:
  425. heartbeat = json.dumps({
  426. "name": CONSUMER_NAME,
  427. "now": now.isoformat(),
  428. "boot_at": BOOT_AT,
  429. "pending": PENDING_TASKS,
  430. "lag": LAG_TASKS,
  431. "done": DONE_TASKS,
  432. "failed": FAILED_TASKS,
  433. "current": CURRENT_TASK,
  434. })
  435. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  436. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  437. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  438. if expired > 0:
  439. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  440. except Exception:
  441. logging.exception("report_status got exception")
  442. time.sleep(30)
  443. def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
  444. msg = ""
  445. if dump_full:
  446. stats2 = snapshot2.statistics('lineno')
  447. msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
  448. for stat in stats2[:10]:
  449. msg += f"{stat}\n"
  450. stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
  451. msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
  452. for stat in stats1_vs_2[:10]:
  453. msg += f"{stat}\n"
  454. msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
  455. for stat in stats1_vs_2[:3]:
  456. msg += '\n'.join(stat.traceback.format())
  457. logging.info(msg)
  458. def main():
  459. logging.info(r"""
  460. ______ __ ______ __
  461. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  462. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  463. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  464. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  465. """)
  466. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  467. settings.init_settings()
  468. print_rag_settings()
  469. background_thread = threading.Thread(target=report_status)
  470. background_thread.daemon = True
  471. background_thread.start()
  472. TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
  473. TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
  474. if TRACE_MALLOC_DELTA > 0:
  475. if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
  476. TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
  477. tracemalloc.start()
  478. snapshot1 = tracemalloc.take_snapshot()
  479. while True:
  480. handle_task()
  481. num_tasks = DONE_TASKS + FAILED_TASKS
  482. if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
  483. snapshot2 = tracemalloc.take_snapshot()
  484. analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
  485. snapshot1 = snapshot2
  486. snapshot2 = None
# Start the executor only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()