Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

task_executor.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import logging
  19. import sys
  20. import os
  21. from api.utils.log_utils import initRootLogger
  22. from datetime import datetime
  23. import json
  24. import hashlib
  25. import copy
  26. import re
  27. import time
  28. import threading
  29. from functools import partial
  30. from io import BytesIO
  31. from multiprocessing.context import TimeoutError
  32. from timeit import default_timer as timer
  33. import tracemalloc
  34. import numpy as np
  35. from api.db import LLMType, ParserType
  36. from api.db.services.dialog_service import keyword_extraction, question_proposal
  37. from api.db.services.document_service import DocumentService
  38. from api.db.services.llm_service import LLMBundle
  39. from api.db.services.task_service import TaskService
  40. from api.db.services.file2document_service import File2DocumentService
  41. from api import settings
  42. from api.versions import get_ragflow_version
  43. from api.db.db_models import close_connection
  44. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  45. knowledge_graph, email
  46. from rag.nlp import search, rag_tokenizer
  47. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  48. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
  49. from rag.utils import rmSpace, num_tokens_from_string
  50. from rag.utils.redis_conn import REDIS_CONN, Payload
  51. from rag.utils.storage_factory import STORAGE_IMPL
  52. CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
  53. CONSUMER_NAME = "task_executor_" + CONSUMER_NO
  54. LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
  55. initRootLogger(CONSUMER_NAME, LOG_LEVELS)
  56. BATCH_SIZE = 64
  57. FACTORY = {
  58. "general": naive,
  59. ParserType.NAIVE.value: naive,
  60. ParserType.PAPER.value: paper,
  61. ParserType.BOOK.value: book,
  62. ParserType.PRESENTATION.value: presentation,
  63. ParserType.MANUAL.value: manual,
  64. ParserType.LAWS.value: laws,
  65. ParserType.QA.value: qa,
  66. ParserType.TABLE.value: table,
  67. ParserType.RESUME.value: resume,
  68. ParserType.PICTURE.value: picture,
  69. ParserType.ONE.value: one,
  70. ParserType.AUDIO.value: audio,
  71. ParserType.EMAIL.value: email,
  72. ParserType.KG.value: knowledge_graph
  73. }
  74. CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
  75. PAYLOAD: Payload | None = None
  76. BOOT_AT = datetime.now().isoformat()
  77. PENDING_TASKS = 0
  78. LAG_TASKS = 0
  79. mt_lock = threading.Lock()
  80. DONE_TASKS = 0
  81. FAILED_TASKS = 0
  82. CURRENT_TASK = None
  83. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  84. global PAYLOAD
  85. if prog is not None and prog < 0:
  86. msg = "[ERROR]" + msg
  87. cancel = TaskService.do_cancel(task_id)
  88. if cancel:
  89. msg += " [Canceled]"
  90. prog = -1
  91. if to_page > 0:
  92. if msg:
  93. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  94. d = {"progress_msg": msg}
  95. if prog is not None:
  96. d["progress"] = prog
  97. try:
  98. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  99. TaskService.update_progress(task_id, d)
  100. except Exception:
  101. logging.exception(f"set_progress({task_id}) got exception")
  102. close_connection()
  103. if cancel:
  104. if PAYLOAD:
  105. PAYLOAD.ack()
  106. PAYLOAD = None
  107. os._exit(0)
  108. def collect():
  109. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  110. try:
  111. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  112. if not PAYLOAD:
  113. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  114. if not PAYLOAD:
  115. time.sleep(1)
  116. return None
  117. except Exception:
  118. logging.exception("Get task event from queue exception")
  119. return None
  120. msg = PAYLOAD.get_message()
  121. if not msg:
  122. return None
  123. if TaskService.do_cancel(msg["id"]):
  124. with mt_lock:
  125. DONE_TASKS += 1
  126. logging.info("Task {} has been canceled.".format(msg["id"]))
  127. return None
  128. task = TaskService.get_task(msg["id"])
  129. if not task:
  130. with mt_lock:
  131. DONE_TASKS += 1
  132. logging.warning("{} empty task!".format(msg["id"]))
  133. return None
  134. if msg.get("type", "") == "raptor":
  135. task["task_type"] = "raptor"
  136. return task
  137. def get_storage_binary(bucket, name):
  138. return STORAGE_IMPL.get(bucket, name)
  139. def build_chunks(task, progress_callback):
  140. if task["size"] > DOC_MAXIMUM_SIZE:
  141. set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  142. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  143. return []
  144. chunker = FACTORY[task["parser_id"].lower()]
  145. try:
  146. st = timer()
  147. bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
  148. binary = get_storage_binary(bucket, name)
  149. logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
  150. except TimeoutError:
  151. progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  152. logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
  153. raise
  154. except Exception as e:
  155. if re.search("(No such file|not found)", str(e)):
  156. progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
  157. else:
  158. progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  159. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  160. raise
  161. try:
  162. cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
  163. to_page=task["to_page"], lang=task["language"], callback=progress_callback,
  164. kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
  165. logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
  166. except Exception as e:
  167. progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
  168. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  169. raise
  170. docs = []
  171. doc = {
  172. "doc_id": task["doc_id"],
  173. "kb_id": str(task["kb_id"])
  174. }
  175. if task["pagerank"]:
  176. doc["pagerank_fea"] = int(task["pagerank"])
  177. el = 0
  178. for ck in cks:
  179. d = copy.deepcopy(doc)
  180. d.update(ck)
  181. md5 = hashlib.md5()
  182. md5.update((ck["content_with_weight"] +
  183. str(d["doc_id"])).encode("utf-8"))
  184. d["id"] = md5.hexdigest()
  185. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  186. d["create_timestamp_flt"] = datetime.now().timestamp()
  187. if not d.get("image"):
  188. _ = d.pop("image", None)
  189. d["img_id"] = ""
  190. d["page_num_list"] = json.dumps([])
  191. d["position_list"] = json.dumps([])
  192. d["top_list"] = json.dumps([])
  193. docs.append(d)
  194. continue
  195. try:
  196. output_buffer = BytesIO()
  197. if isinstance(d["image"], bytes):
  198. output_buffer = BytesIO(d["image"])
  199. else:
  200. d["image"].save(output_buffer, format='JPEG')
  201. st = timer()
  202. STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
  203. el += timer() - st
  204. except Exception:
  205. logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["_id"]))
  206. raise
  207. d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
  208. del d["image"]
  209. docs.append(d)
  210. logging.info("MINIO PUT({}):{}".format(task["name"], el))
  211. if task["parser_config"].get("auto_keywords", 0):
  212. st = timer()
  213. progress_callback(msg="Start to generate keywords for every chunk ...")
  214. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  215. for d in docs:
  216. d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
  217. task["parser_config"]["auto_keywords"]).split(",")
  218. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  219. progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
  220. if task["parser_config"].get("auto_questions", 0):
  221. st = timer()
  222. progress_callback(msg="Start to generate questions for every chunk ...")
  223. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  224. for d in docs:
  225. d["question_kwd"] = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]).split("\n")
  226. d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
  227. progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
  228. return docs
  229. def init_kb(row, vector_size: int):
  230. idxnm = search.index_name(row["tenant_id"])
  231. return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
  232. def embedding(docs, mdl, parser_config=None, callback=None):
  233. if parser_config is None:
  234. parser_config = {}
  235. batch_size = 16
  236. tts, cnts = [], []
  237. for d in docs:
  238. tts.append(rmSpace(d["title_tks"]))
  239. c = "\n".join(d.get("question_kwd", []))
  240. if not c:
  241. c = d["content_with_weight"]
  242. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  243. cnts.append(c)
  244. tk_count = 0
  245. if len(tts) == len(cnts):
  246. tts_ = np.array([])
  247. for i in range(0, len(tts), batch_size):
  248. vts, c = mdl.encode(tts[i: i + batch_size])
  249. if len(tts_) == 0:
  250. tts_ = vts
  251. else:
  252. tts_ = np.concatenate((tts_, vts), axis=0)
  253. tk_count += c
  254. callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
  255. tts = tts_
  256. cnts_ = np.array([])
  257. for i in range(0, len(cnts), batch_size):
  258. vts, c = mdl.encode(cnts[i: i + batch_size])
  259. if len(cnts_) == 0:
  260. cnts_ = vts
  261. else:
  262. cnts_ = np.concatenate((cnts_, vts), axis=0)
  263. tk_count += c
  264. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  265. cnts = cnts_
  266. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  267. vects = (title_w * tts + (1 - title_w) *
  268. cnts) if len(tts) == len(cnts) else cnts
  269. assert len(vects) == len(docs)
  270. vector_size = 0
  271. for i, d in enumerate(docs):
  272. v = vects[i].tolist()
  273. vector_size = len(v)
  274. d["q_%d_vec" % len(v)] = v
  275. return tk_count, vector_size
  276. def run_raptor(row, chat_mdl, embd_mdl, callback=None):
  277. vts, _ = embd_mdl.encode(["ok"])
  278. vector_size = len(vts[0])
  279. vctr_nm = "q_%d_vec" % vector_size
  280. chunks = []
  281. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  282. fields=["content_with_weight", vctr_nm]):
  283. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  284. raptor = Raptor(
  285. row["parser_config"]["raptor"].get("max_cluster", 64),
  286. chat_mdl,
  287. embd_mdl,
  288. row["parser_config"]["raptor"]["prompt"],
  289. row["parser_config"]["raptor"]["max_token"],
  290. row["parser_config"]["raptor"]["threshold"]
  291. )
  292. original_length = len(chunks)
  293. chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  294. doc = {
  295. "doc_id": row["doc_id"],
  296. "kb_id": [str(row["kb_id"])],
  297. "docnm_kwd": row["name"],
  298. "title_tks": rag_tokenizer.tokenize(row["name"])
  299. }
  300. if row["pagerank"]:
  301. doc["pagerank_fea"] = int(row["pagerank"])
  302. res = []
  303. tk_count = 0
  304. for content, vctr in chunks[original_length:]:
  305. d = copy.deepcopy(doc)
  306. md5 = hashlib.md5()
  307. md5.update((content + str(d["doc_id"])).encode("utf-8"))
  308. d["id"] = md5.hexdigest()
  309. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  310. d["create_timestamp_flt"] = datetime.now().timestamp()
  311. d[vctr_nm] = vctr.tolist()
  312. d["content_with_weight"] = content
  313. d["content_ltks"] = rag_tokenizer.tokenize(content)
  314. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  315. res.append(d)
  316. tk_count += num_tokens_from_string(content)
  317. return res, tk_count, vector_size
  318. def do_handle_task(task):
  319. task_id = task["id"]
  320. task_from_page = task["from_page"]
  321. task_to_page = task["to_page"]
  322. task_tenant_id = task["tenant_id"]
  323. task_embedding_id = task["embd_id"]
  324. task_language = task["language"]
  325. task_llm_id = task["llm_id"]
  326. task_dataset_id = task["kb_id"]
  327. task_doc_id = task["doc_id"]
  328. task_document_name = task["name"]
  329. task_parser_config = task["parser_config"]
  330. # prepare the progress callback function
  331. progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
  332. try:
  333. # bind embedding model
  334. embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
  335. except Exception as e:
  336. error_message = f'Fail to bind embedding model: {str(e)}'
  337. progress_callback(-1, msg=error_message)
  338. logging.exception(error_message)
  339. raise
  340. # Either using RAPTOR or Standard chunking methods
  341. if task.get("task_type", "") == "raptor":
  342. try:
  343. # bind LLM for raptor
  344. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  345. # run RAPTOR
  346. chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
  347. except Exception as e:
  348. error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
  349. progress_callback(-1, msg=error_message)
  350. logging.exception(error_message)
  351. raise
  352. else:
  353. # Standard chunking methods
  354. start_ts = timer()
  355. chunks = build_chunks(task, progress_callback)
  356. logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
  357. if chunks is None:
  358. return
  359. if not chunks:
  360. progress_callback(1., msg=f"No chunk built from {task_document_name}")
  361. return
  362. # TODO: exception handler
  363. ## set_progress(task["did"], -1, "ERROR: ")
  364. progress_callback(msg="Generate {} chunks".format(len(chunks)))
  365. start_ts = timer()
  366. try:
  367. token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
  368. except Exception as e:
  369. error_message = "Generate embedding error:{}".format(str(e))
  370. progress_callback(-1, error_message)
  371. logging.exception(error_message)
  372. token_count = 0
  373. raise
  374. progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
  375. logging.info(progress_message)
  376. progress_callback(msg=progress_message)
  377. # logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
  378. init_kb(task, vector_size)
  379. chunk_count = len(set([chunk["id"] for chunk in chunks]))
  380. start_ts = timer()
  381. doc_store_result = ""
  382. es_bulk_size = 4
  383. for b in range(0, len(chunks), es_bulk_size):
  384. doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
  385. if b % 128 == 0:
  386. progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
  387. logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
  388. if doc_store_result:
  389. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  390. progress_callback(-1, msg=error_message)
  391. settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
  392. logging.error(error_message)
  393. raise Exception(error_message)
  394. if TaskService.do_cancel(task_id):
  395. settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
  396. return
  397. DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
  398. time_cost = timer() - start_ts
  399. progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
  400. logging.info("Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(task_id, token_count, len(chunks), time_cost))
  401. def handle_task():
  402. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  403. task = collect()
  404. if task:
  405. try:
  406. logging.info(f"handle_task begin for task {json.dumps(task)}")
  407. with mt_lock:
  408. CURRENT_TASK = copy.deepcopy(task)
  409. do_handle_task(task)
  410. with mt_lock:
  411. DONE_TASKS += 1
  412. CURRENT_TASK = None
  413. logging.info(f"handle_task done for task {json.dumps(task)}")
  414. except Exception:
  415. with mt_lock:
  416. FAILED_TASKS += 1
  417. CURRENT_TASK = None
  418. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  419. if PAYLOAD:
  420. PAYLOAD.ack()
  421. PAYLOAD = None
  422. def report_status():
  423. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  424. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  425. while True:
  426. try:
  427. now = datetime.now()
  428. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  429. if group_info is not None:
  430. PENDING_TASKS = int(group_info["pending"])
  431. LAG_TASKS = int(group_info["lag"])
  432. with mt_lock:
  433. heartbeat = json.dumps({
  434. "name": CONSUMER_NAME,
  435. "now": now.isoformat(),
  436. "boot_at": BOOT_AT,
  437. "pending": PENDING_TASKS,
  438. "lag": LAG_TASKS,
  439. "done": DONE_TASKS,
  440. "failed": FAILED_TASKS,
  441. "current": CURRENT_TASK,
  442. })
  443. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  444. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  445. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  446. if expired > 0:
  447. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  448. except Exception:
  449. logging.exception("report_status got exception")
  450. time.sleep(30)
  451. def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
  452. msg = ""
  453. if dump_full:
  454. stats2 = snapshot2.statistics('lineno')
  455. msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
  456. for stat in stats2[:10]:
  457. msg += f"{stat}\n"
  458. stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
  459. msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
  460. for stat in stats1_vs_2[:10]:
  461. msg += f"{stat}\n"
  462. msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
  463. for stat in stats1_vs_2[:3]:
  464. msg += '\n'.join(stat.traceback.format())
  465. logging.info(msg)
  466. def main():
  467. logging.info(r"""
  468. ______ __ ______ __
  469. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  470. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  471. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  472. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  473. """)
  474. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  475. settings.init_settings()
  476. print_rag_settings()
  477. background_thread = threading.Thread(target=report_status)
  478. background_thread.daemon = True
  479. background_thread.start()
  480. TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
  481. TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
  482. if TRACE_MALLOC_DELTA > 0:
  483. if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
  484. TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
  485. tracemalloc.start()
  486. snapshot1 = tracemalloc.take_snapshot()
  487. while True:
  488. handle_task()
  489. num_tasks = DONE_TASKS + FAILED_TASKS
  490. if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
  491. snapshot2 = tracemalloc.take_snapshot()
  492. analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
  493. snapshot1 = snapshot2
  494. snapshot2 = None
  495. if __name__ == "__main__":
  496. main()