您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

task_executor.py 24KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import sys
  19. from api.utils.log_utils import initRootLogger
  20. from graphrag.utils import get_llm_cache, set_llm_cache
  21. CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
  22. CONSUMER_NAME = "task_executor_" + CONSUMER_NO
  23. initRootLogger(CONSUMER_NAME)
  24. import logging
  25. import os
  26. from datetime import datetime
  27. import json
  28. import xxhash
  29. import copy
  30. import re
  31. import time
  32. import threading
  33. from functools import partial
  34. from io import BytesIO
  35. from multiprocessing.context import TimeoutError
  36. from timeit import default_timer as timer
  37. import tracemalloc
  38. import numpy as np
  39. from peewee import DoesNotExist
  40. from api.db import LLMType, ParserType, TaskStatus
  41. from api.db.services.dialog_service import keyword_extraction, question_proposal
  42. from api.db.services.document_service import DocumentService
  43. from api.db.services.llm_service import LLMBundle
  44. from api.db.services.task_service import TaskService
  45. from api.db.services.file2document_service import File2DocumentService
  46. from api import settings
  47. from api.versions import get_ragflow_version
  48. from api.db.db_models import close_connection
  49. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  50. knowledge_graph, email
  51. from rag.nlp import search, rag_tokenizer
  52. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  53. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
  54. from rag.utils import num_tokens_from_string
  55. from rag.utils.redis_conn import REDIS_CONN, Payload
  56. from rag.utils.storage_factory import STORAGE_IMPL
  57. BATCH_SIZE = 64
  58. FACTORY = {
  59. "general": naive,
  60. ParserType.NAIVE.value: naive,
  61. ParserType.PAPER.value: paper,
  62. ParserType.BOOK.value: book,
  63. ParserType.PRESENTATION.value: presentation,
  64. ParserType.MANUAL.value: manual,
  65. ParserType.LAWS.value: laws,
  66. ParserType.QA.value: qa,
  67. ParserType.TABLE.value: table,
  68. ParserType.RESUME.value: resume,
  69. ParserType.PICTURE.value: picture,
  70. ParserType.ONE.value: one,
  71. ParserType.AUDIO.value: audio,
  72. ParserType.EMAIL.value: email,
  73. ParserType.KG.value: knowledge_graph
  74. }
  75. CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
  76. PAYLOAD: Payload | None = None
  77. BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
  78. PENDING_TASKS = 0
  79. LAG_TASKS = 0
  80. mt_lock = threading.Lock()
  81. DONE_TASKS = 0
  82. FAILED_TASKS = 0
  83. CURRENT_TASK = None
  84. class TaskCanceledException(Exception):
  85. def __init__(self, msg):
  86. self.msg = msg
  87. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  88. global PAYLOAD
  89. if prog is not None and prog < 0:
  90. msg = "[ERROR]" + msg
  91. try:
  92. cancel = TaskService.do_cancel(task_id)
  93. except DoesNotExist:
  94. logging.warning(f"set_progress task {task_id} is unknown")
  95. if PAYLOAD:
  96. PAYLOAD.ack()
  97. PAYLOAD = None
  98. return
  99. if cancel:
  100. msg += " [Canceled]"
  101. prog = -1
  102. if to_page > 0:
  103. if msg:
  104. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  105. if msg:
  106. msg = datetime.now().strftime("%H:%M:%S") + " " + msg
  107. d = {"progress_msg": msg}
  108. if prog is not None:
  109. d["progress"] = prog
  110. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  111. try:
  112. TaskService.update_progress(task_id, d)
  113. except DoesNotExist:
  114. logging.warning(f"set_progress task {task_id} is unknown")
  115. if PAYLOAD:
  116. PAYLOAD.ack()
  117. PAYLOAD = None
  118. return
  119. close_connection()
  120. if cancel and PAYLOAD:
  121. PAYLOAD.ack()
  122. PAYLOAD = None
  123. raise TaskCanceledException(msg)
  124. def collect():
  125. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  126. try:
  127. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  128. if not PAYLOAD:
  129. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  130. if not PAYLOAD:
  131. time.sleep(1)
  132. return None
  133. except Exception:
  134. logging.exception("Get task event from queue exception")
  135. return None
  136. msg = PAYLOAD.get_message()
  137. if not msg:
  138. return None
  139. task = None
  140. canceled = False
  141. try:
  142. task = TaskService.get_task(msg["id"])
  143. if task:
  144. _, doc = DocumentService.get_by_id(task["doc_id"])
  145. canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
  146. except DoesNotExist:
  147. pass
  148. except Exception:
  149. logging.exception("collect get_task exception")
  150. if not task or canceled:
  151. state = "is unknown" if not task else "has been cancelled"
  152. with mt_lock:
  153. DONE_TASKS += 1
  154. logging.info(f"collect task {msg['id']} {state}")
  155. return None
  156. if msg.get("type", "") == "raptor":
  157. task["task_type"] = "raptor"
  158. return task
  159. def get_storage_binary(bucket, name):
  160. return STORAGE_IMPL.get(bucket, name)
  161. def build_chunks(task, progress_callback):
  162. if task["size"] > DOC_MAXIMUM_SIZE:
  163. set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  164. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  165. return []
  166. chunker = FACTORY[task["parser_id"].lower()]
  167. try:
  168. st = timer()
  169. bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
  170. binary = get_storage_binary(bucket, name)
  171. logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
  172. except TimeoutError:
  173. progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  174. logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
  175. raise
  176. except Exception as e:
  177. if re.search("(No such file|not found)", str(e)):
  178. progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
  179. else:
  180. progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  181. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  182. raise
  183. try:
  184. cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
  185. to_page=task["to_page"], lang=task["language"], callback=progress_callback,
  186. kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
  187. logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
  188. except TaskCanceledException:
  189. raise
  190. except Exception as e:
  191. progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
  192. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  193. raise
  194. docs = []
  195. doc = {
  196. "doc_id": task["doc_id"],
  197. "kb_id": str(task["kb_id"])
  198. }
  199. if task["pagerank"]:
  200. doc["pagerank_fea"] = int(task["pagerank"])
  201. el = 0
  202. for ck in cks:
  203. d = copy.deepcopy(doc)
  204. d.update(ck)
  205. d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
  206. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  207. d["create_timestamp_flt"] = datetime.now().timestamp()
  208. if not d.get("image"):
  209. _ = d.pop("image", None)
  210. d["img_id"] = ""
  211. docs.append(d)
  212. continue
  213. try:
  214. output_buffer = BytesIO()
  215. if isinstance(d["image"], bytes):
  216. output_buffer = BytesIO(d["image"])
  217. else:
  218. d["image"].save(output_buffer, format='JPEG')
  219. st = timer()
  220. STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
  221. el += timer() - st
  222. except Exception:
  223. logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
  224. raise
  225. d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
  226. del d["image"]
  227. docs.append(d)
  228. logging.info("MINIO PUT({}):{}".format(task["name"], el))
  229. if task["parser_config"].get("auto_keywords", 0):
  230. st = timer()
  231. progress_callback(msg="Start to generate keywords for every chunk ...")
  232. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  233. for d in docs:
  234. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
  235. {"topn": task["parser_config"]["auto_keywords"]})
  236. if not cached:
  237. cached = keyword_extraction(chat_mdl, d["content_with_weight"],
  238. task["parser_config"]["auto_keywords"])
  239. if cached:
  240. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
  241. {"topn": task["parser_config"]["auto_keywords"]})
  242. d["important_kwd"] = cached.split(",")
  243. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  244. progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
  245. if task["parser_config"].get("auto_questions", 0):
  246. st = timer()
  247. progress_callback(msg="Start to generate questions for every chunk ...")
  248. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  249. for d in docs:
  250. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
  251. {"topn": task["parser_config"]["auto_questions"]})
  252. if not cached:
  253. cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
  254. if cached:
  255. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
  256. {"topn": task["parser_config"]["auto_questions"]})
  257. d["question_kwd"] = cached.split("\n")
  258. d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
  259. progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
  260. return docs
  261. def init_kb(row, vector_size: int):
  262. idxnm = search.index_name(row["tenant_id"])
  263. return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size)
  264. def embedding(docs, mdl, parser_config=None, callback=None):
  265. if parser_config is None:
  266. parser_config = {}
  267. batch_size = 16
  268. tts, cnts = [], []
  269. for d in docs:
  270. tts.append(d.get("docnm_kwd", "Title"))
  271. c = "\n".join(d.get("question_kwd", []))
  272. if not c:
  273. c = d["content_with_weight"]
  274. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  275. if not c:
  276. c = "None"
  277. cnts.append(c)
  278. tk_count = 0
  279. if len(tts) == len(cnts):
  280. tts_ = np.array([])
  281. for i in range(0, len(tts), batch_size):
  282. vts, c = mdl.encode(tts[i: i + batch_size])
  283. if len(tts_) == 0:
  284. tts_ = vts
  285. else:
  286. tts_ = np.concatenate((tts_, vts), axis=0)
  287. tk_count += c
  288. callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
  289. tts = tts_
  290. cnts_ = np.array([])
  291. for i in range(0, len(cnts), batch_size):
  292. vts, c = mdl.encode(cnts[i: i + batch_size])
  293. if len(cnts_) == 0:
  294. cnts_ = vts
  295. else:
  296. cnts_ = np.concatenate((cnts_, vts), axis=0)
  297. tk_count += c
  298. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  299. cnts = cnts_
  300. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  301. vects = (title_w * tts + (1 - title_w) *
  302. cnts) if len(tts) == len(cnts) else cnts
  303. assert len(vects) == len(docs)
  304. vector_size = 0
  305. for i, d in enumerate(docs):
  306. v = vects[i].tolist()
  307. vector_size = len(v)
  308. d["q_%d_vec" % len(v)] = v
  309. return tk_count, vector_size
  310. def run_raptor(row, chat_mdl, embd_mdl, callback=None):
  311. vts, _ = embd_mdl.encode(["ok"])
  312. vector_size = len(vts[0])
  313. vctr_nm = "q_%d_vec" % vector_size
  314. chunks = []
  315. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  316. fields=["content_with_weight", vctr_nm]):
  317. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  318. raptor = Raptor(
  319. row["parser_config"]["raptor"].get("max_cluster", 64),
  320. chat_mdl,
  321. embd_mdl,
  322. row["parser_config"]["raptor"]["prompt"],
  323. row["parser_config"]["raptor"]["max_token"],
  324. row["parser_config"]["raptor"]["threshold"]
  325. )
  326. original_length = len(chunks)
  327. chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  328. doc = {
  329. "doc_id": row["doc_id"],
  330. "kb_id": [str(row["kb_id"])],
  331. "docnm_kwd": row["name"],
  332. "title_tks": rag_tokenizer.tokenize(row["name"])
  333. }
  334. if row["pagerank"]:
  335. doc["pagerank_fea"] = int(row["pagerank"])
  336. res = []
  337. tk_count = 0
  338. for content, vctr in chunks[original_length:]:
  339. d = copy.deepcopy(doc)
  340. d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
  341. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  342. d["create_timestamp_flt"] = datetime.now().timestamp()
  343. d[vctr_nm] = vctr.tolist()
  344. d["content_with_weight"] = content
  345. d["content_ltks"] = rag_tokenizer.tokenize(content)
  346. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  347. res.append(d)
  348. tk_count += num_tokens_from_string(content)
  349. return res, tk_count, vector_size
  350. def do_handle_task(task):
  351. task_id = task["id"]
  352. task_from_page = task["from_page"]
  353. task_to_page = task["to_page"]
  354. task_tenant_id = task["tenant_id"]
  355. task_embedding_id = task["embd_id"]
  356. task_language = task["language"]
  357. task_llm_id = task["llm_id"]
  358. task_dataset_id = task["kb_id"]
  359. task_doc_id = task["doc_id"]
  360. task_document_name = task["name"]
  361. task_parser_config = task["parser_config"]
  362. # prepare the progress callback function
  363. progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
  364. try:
  365. task_canceled = TaskService.do_cancel(task_id)
  366. except DoesNotExist:
  367. logging.warning(f"task {task_id} is unknown")
  368. return
  369. if task_canceled:
  370. progress_callback(-1, msg="Task has been canceled.")
  371. return
  372. try:
  373. # bind embedding model
  374. embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
  375. except Exception as e:
  376. error_message = f'Fail to bind embedding model: {str(e)}'
  377. progress_callback(-1, msg=error_message)
  378. logging.exception(error_message)
  379. raise
  380. # Either using RAPTOR or Standard chunking methods
  381. if task.get("task_type", "") == "raptor":
  382. try:
  383. # bind LLM for raptor
  384. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  385. # run RAPTOR
  386. chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
  387. except TaskCanceledException:
  388. raise
  389. except Exception as e:
  390. error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
  391. progress_callback(-1, msg=error_message)
  392. logging.exception(error_message)
  393. raise
  394. else:
  395. # Standard chunking methods
  396. start_ts = timer()
  397. chunks = build_chunks(task, progress_callback)
  398. logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
  399. if chunks is None:
  400. return
  401. if not chunks:
  402. progress_callback(1., msg=f"No chunk built from {task_document_name}")
  403. return
  404. # TODO: exception handler
  405. ## set_progress(task["did"], -1, "ERROR: ")
  406. progress_callback(msg="Generate {} chunks".format(len(chunks)))
  407. start_ts = timer()
  408. try:
  409. token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
  410. except Exception as e:
  411. error_message = "Generate embedding error:{}".format(str(e))
  412. progress_callback(-1, error_message)
  413. logging.exception(error_message)
  414. token_count = 0
  415. raise
  416. progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
  417. logging.info(progress_message)
  418. progress_callback(msg=progress_message)
  419. # logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
  420. init_kb(task, vector_size)
  421. chunk_count = len(set([chunk["id"] for chunk in chunks]))
  422. start_ts = timer()
  423. doc_store_result = ""
  424. es_bulk_size = 4
  425. for b in range(0, len(chunks), es_bulk_size):
  426. doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
  427. if b % 128 == 0:
  428. progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
  429. if doc_store_result:
  430. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  431. progress_callback(-1, msg=error_message)
  432. raise Exception(error_message)
  433. chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
  434. chunk_ids_str = " ".join(chunk_ids)
  435. try:
  436. TaskService.update_chunk_ids(task["id"], chunk_ids_str)
  437. except DoesNotExist:
  438. logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
  439. doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
  440. return
  441. logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
  442. DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
  443. time_cost = timer() - start_ts
  444. progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
  445. logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
  446. def handle_task():
  447. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  448. task = collect()
  449. if task:
  450. try:
  451. logging.info(f"handle_task begin for task {json.dumps(task)}")
  452. with mt_lock:
  453. CURRENT_TASK = copy.deepcopy(task)
  454. do_handle_task(task)
  455. with mt_lock:
  456. DONE_TASKS += 1
  457. CURRENT_TASK = None
  458. logging.info(f"handle_task done for task {json.dumps(task)}")
  459. except TaskCanceledException:
  460. with mt_lock:
  461. DONE_TASKS += 1
  462. CURRENT_TASK = None
  463. try:
  464. set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
  465. except Exception:
  466. pass
  467. logging.debug("handle_task got TaskCanceledException", exc_info=True)
  468. except Exception as e:
  469. with mt_lock:
  470. FAILED_TASKS += 1
  471. CURRENT_TASK = None
  472. try:
  473. set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
  474. except Exception:
  475. pass
  476. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  477. if PAYLOAD:
  478. PAYLOAD.ack()
  479. PAYLOAD = None
  480. def report_status():
  481. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  482. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  483. while True:
  484. try:
  485. now = datetime.now()
  486. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  487. if group_info is not None:
  488. PENDING_TASKS = int(group_info.get("pending", 0))
  489. LAG_TASKS = int(group_info.get("lag", 0))
  490. with mt_lock:
  491. heartbeat = json.dumps({
  492. "name": CONSUMER_NAME,
  493. "now": now.astimezone().isoformat(timespec="milliseconds"),
  494. "boot_at": BOOT_AT,
  495. "pending": PENDING_TASKS,
  496. "lag": LAG_TASKS,
  497. "done": DONE_TASKS,
  498. "failed": FAILED_TASKS,
  499. "current": CURRENT_TASK,
  500. })
  501. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  502. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  503. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  504. if expired > 0:
  505. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  506. except Exception:
  507. logging.exception("report_status got exception")
  508. time.sleep(30)
  509. def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
  510. msg = ""
  511. if dump_full:
  512. stats2 = snapshot2.statistics('lineno')
  513. msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
  514. for stat in stats2[:10]:
  515. msg += f"{stat}\n"
  516. stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
  517. msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
  518. for stat in stats1_vs_2[:10]:
  519. msg += f"{stat}\n"
  520. msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
  521. for stat in stats1_vs_2[:3]:
  522. msg += '\n'.join(stat.traceback.format())
  523. logging.info(msg)
  524. def main():
  525. logging.info(r"""
  526. ______ __ ______ __
  527. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  528. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  529. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  530. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  531. """)
  532. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  533. settings.init_settings()
  534. print_rag_settings()
  535. background_thread = threading.Thread(target=report_status)
  536. background_thread.daemon = True
  537. background_thread.start()
  538. TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
  539. TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
  540. if TRACE_MALLOC_DELTA > 0:
  541. if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
  542. TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
  543. tracemalloc.start()
  544. snapshot1 = tracemalloc.take_snapshot()
  545. while True:
  546. handle_task()
  547. num_tasks = DONE_TASKS + FAILED_TASKS
  548. if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
  549. snapshot2 = tracemalloc.take_snapshot()
  550. analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
  551. snapshot1 = snapshot2
  552. snapshot2 = None
  553. if __name__ == "__main__":
  554. main()