Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

task_executor.py 30KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import random
  19. import sys
  20. from api.utils.log_utils import initRootLogger
  21. from graphrag.general.index import WithCommunity, WithResolution, Dealer
  22. from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
  23. from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
  24. from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
  25. CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
  26. CONSUMER_NAME = "task_executor_" + CONSUMER_NO
  27. initRootLogger(CONSUMER_NAME)
  28. import logging
  29. import os
  30. from datetime import datetime
  31. import json
  32. import xxhash
  33. import copy
  34. import re
  35. import time
  36. import threading
  37. from functools import partial
  38. from io import BytesIO
  39. from multiprocessing.context import TimeoutError
  40. from timeit import default_timer as timer
  41. import tracemalloc
  42. import numpy as np
  43. from peewee import DoesNotExist
  44. from api.db import LLMType, ParserType, TaskStatus
  45. from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
  46. from api.db.services.document_service import DocumentService
  47. from api.db.services.llm_service import LLMBundle
  48. from api.db.services.task_service import TaskService
  49. from api.db.services.file2document_service import File2DocumentService
  50. from api import settings
  51. from api.versions import get_ragflow_version
  52. from api.db.db_models import close_connection
  53. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  54. email, tag
  55. from rag.nlp import search, rag_tokenizer
  56. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  57. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
  58. from rag.utils import num_tokens_from_string
  59. from rag.utils.redis_conn import REDIS_CONN, Payload
  60. from rag.utils.storage_factory import STORAGE_IMPL
# Not referenced by the code in this file.
# NOTE(review): possibly imported elsewhere — confirm before removing.
BATCH_SIZE = 64
# Chunker module per parser id: build_chunks looks up task["parser_id"].lower()
# in this table and calls the module's chunk().
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: naive,
    ParserType.TAG.value: tag
}
# NOTE(review): this re-binds CONSUMER_NAME. The root logger was initialised with
# "task_executor_<N>"; from here on the Redis queue consumer, the heartbeat keys
# and the heap reports use "task_consumer_<N>". Confirm the split is intentional.
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
# Queue message of the task being processed; acknowledged and reset to None once handled.
PAYLOAD: Payload | None = None
# Module load time, reported as "boot_at" in every heartbeat.
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
# Consumer-group backlog figures ("pending" / "lag"), refreshed by report_status().
PENDING_TASKS = 0
LAG_TASKS = 0
# Guards DONE_TASKS / FAILED_TASKS / CURRENT_TASK, which the worker loop writes
# and the heartbeat thread reads.
mt_lock = threading.Lock()
# Tasks finished by this worker (canceled and unknown tasks count as done) and tasks that failed.
DONE_TASKS = 0
FAILED_TASKS = 0
# Deep copy of the task being handled (None when idle); included in heartbeats.
CURRENT_TASK = None
  89. class TaskCanceledException(Exception):
  90. def __init__(self, msg):
  91. self.msg = msg
  92. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  93. global PAYLOAD
  94. if prog is not None and prog < 0:
  95. msg = "[ERROR]" + msg
  96. try:
  97. cancel = TaskService.do_cancel(task_id)
  98. except DoesNotExist:
  99. logging.warning(f"set_progress task {task_id} is unknown")
  100. if PAYLOAD:
  101. PAYLOAD.ack()
  102. PAYLOAD = None
  103. return
  104. if cancel:
  105. msg += " [Canceled]"
  106. prog = -1
  107. if to_page > 0:
  108. if msg:
  109. if from_page < to_page:
  110. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  111. if msg:
  112. msg = datetime.now().strftime("%H:%M:%S") + " " + msg
  113. d = {"progress_msg": msg}
  114. if prog is not None:
  115. d["progress"] = prog
  116. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  117. try:
  118. TaskService.update_progress(task_id, d)
  119. except DoesNotExist:
  120. logging.warning(f"set_progress task {task_id} is unknown")
  121. if PAYLOAD:
  122. PAYLOAD.ack()
  123. PAYLOAD = None
  124. return
  125. close_connection()
  126. if cancel and PAYLOAD:
  127. PAYLOAD.ack()
  128. PAYLOAD = None
  129. raise TaskCanceledException(msg)
  130. def collect():
  131. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  132. try:
  133. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  134. if not PAYLOAD:
  135. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  136. if not PAYLOAD:
  137. time.sleep(1)
  138. return None
  139. except Exception:
  140. logging.exception("Get task event from queue exception")
  141. return None
  142. msg = PAYLOAD.get_message()
  143. if not msg:
  144. return None
  145. task = None
  146. canceled = False
  147. try:
  148. task = TaskService.get_task(msg["id"])
  149. if task:
  150. _, doc = DocumentService.get_by_id(task["doc_id"])
  151. canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
  152. except DoesNotExist:
  153. pass
  154. except Exception:
  155. logging.exception("collect get_task exception")
  156. if not task or canceled:
  157. state = "is unknown" if not task else "has been cancelled"
  158. with mt_lock:
  159. DONE_TASKS += 1
  160. logging.info(f"collect task {msg['id']} {state}")
  161. return None
  162. task["task_type"] = msg.get("task_type", "")
  163. return task
  164. def get_storage_binary(bucket, name):
  165. return STORAGE_IMPL.get(bucket, name)
  166. def build_chunks(task, progress_callback):
  167. if task["size"] > DOC_MAXIMUM_SIZE:
  168. set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  169. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  170. return []
  171. chunker = FACTORY[task["parser_id"].lower()]
  172. try:
  173. st = timer()
  174. bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
  175. binary = get_storage_binary(bucket, name)
  176. logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
  177. except TimeoutError:
  178. progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  179. logging.exception(
  180. "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
  181. raise
  182. except Exception as e:
  183. if re.search("(No such file|not found)", str(e)):
  184. progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
  185. else:
  186. progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  187. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  188. raise
  189. try:
  190. cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
  191. to_page=task["to_page"], lang=task["language"], callback=progress_callback,
  192. kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
  193. logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
  194. except TaskCanceledException:
  195. raise
  196. except Exception as e:
  197. progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
  198. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  199. raise
  200. docs = []
  201. doc = {
  202. "doc_id": task["doc_id"],
  203. "kb_id": str(task["kb_id"])
  204. }
  205. if task["pagerank"]:
  206. doc[PAGERANK_FLD] = int(task["pagerank"])
  207. el = 0
  208. for ck in cks:
  209. d = copy.deepcopy(doc)
  210. d.update(ck)
  211. d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
  212. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  213. d["create_timestamp_flt"] = datetime.now().timestamp()
  214. if not d.get("image"):
  215. _ = d.pop("image", None)
  216. d["img_id"] = ""
  217. docs.append(d)
  218. continue
  219. try:
  220. output_buffer = BytesIO()
  221. if isinstance(d["image"], bytes):
  222. output_buffer = BytesIO(d["image"])
  223. else:
  224. d["image"].save(output_buffer, format='JPEG')
  225. st = timer()
  226. STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
  227. el += timer() - st
  228. except Exception:
  229. logging.exception(
  230. "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
  231. raise
  232. d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
  233. del d["image"]
  234. docs.append(d)
  235. logging.info("MINIO PUT({}):{}".format(task["name"], el))
  236. if task["parser_config"].get("auto_keywords", 0):
  237. st = timer()
  238. progress_callback(msg="Start to generate keywords for every chunk ...")
  239. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  240. for d in docs:
  241. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
  242. {"topn": task["parser_config"]["auto_keywords"]})
  243. if not cached:
  244. cached = keyword_extraction(chat_mdl, d["content_with_weight"],
  245. task["parser_config"]["auto_keywords"])
  246. if cached:
  247. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
  248. {"topn": task["parser_config"]["auto_keywords"]})
  249. d["important_kwd"] = cached.split(",")
  250. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  251. progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
  252. if task["parser_config"].get("auto_questions", 0):
  253. st = timer()
  254. progress_callback(msg="Start to generate questions for every chunk ...")
  255. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  256. for d in docs:
  257. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
  258. {"topn": task["parser_config"]["auto_questions"]})
  259. if not cached:
  260. cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
  261. if cached:
  262. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
  263. {"topn": task["parser_config"]["auto_questions"]})
  264. d["question_kwd"] = cached.split("\n")
  265. d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
  266. progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
  267. if task["kb_parser_config"].get("tag_kb_ids", []):
  268. progress_callback(msg="Start to tag for every chunk ...")
  269. kb_ids = task["kb_parser_config"]["tag_kb_ids"]
  270. tenant_id = task["tenant_id"]
  271. topn_tags = task["kb_parser_config"].get("topn_tags", 3)
  272. S = 1000
  273. st = timer()
  274. examples = []
  275. all_tags = get_tags_from_cache(kb_ids)
  276. if not all_tags:
  277. all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
  278. set_tags_to_cache(kb_ids, all_tags)
  279. else:
  280. all_tags = json.loads(all_tags)
  281. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  282. for d in docs:
  283. if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
  284. examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
  285. continue
  286. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
  287. if not cached:
  288. cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
  289. random.choices(examples, k=2) if len(examples)>2 else examples,
  290. topn=topn_tags)
  291. if cached:
  292. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
  293. d[TAG_FLD] = json.loads(cached)
  294. progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
  295. return docs
  296. def init_kb(row, vector_size: int):
  297. idxnm = search.index_name(row["tenant_id"])
  298. return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
  299. def embedding(docs, mdl, parser_config=None, callback=None):
  300. if parser_config is None:
  301. parser_config = {}
  302. batch_size = 16
  303. tts, cnts = [], []
  304. for d in docs:
  305. tts.append(d.get("docnm_kwd", "Title"))
  306. c = "\n".join(d.get("question_kwd", []))
  307. if not c:
  308. c = d["content_with_weight"]
  309. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  310. if not c:
  311. c = "None"
  312. cnts.append(c)
  313. tk_count = 0
  314. if len(tts) == len(cnts):
  315. vts, c = mdl.encode(tts[0: 1])
  316. tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
  317. tk_count += c
  318. cnts_ = np.array([])
  319. for i in range(0, len(cnts), batch_size):
  320. vts, c = mdl.encode(cnts[i: i + batch_size])
  321. if len(cnts_) == 0:
  322. cnts_ = vts
  323. else:
  324. cnts_ = np.concatenate((cnts_, vts), axis=0)
  325. tk_count += c
  326. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  327. cnts = cnts_
  328. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  329. vects = (title_w * tts + (1 - title_w) *
  330. cnts) if len(tts) == len(cnts) else cnts
  331. assert len(vects) == len(docs)
  332. vector_size = 0
  333. for i, d in enumerate(docs):
  334. v = vects[i].tolist()
  335. vector_size = len(v)
  336. d["q_%d_vec" % len(v)] = v
  337. return tk_count, vector_size
  338. def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
  339. chunks = []
  340. vctr_nm = "q_%d_vec"%vector_size
  341. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  342. fields=["content_with_weight", vctr_nm]):
  343. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  344. raptor = Raptor(
  345. row["parser_config"]["raptor"].get("max_cluster", 64),
  346. chat_mdl,
  347. embd_mdl,
  348. row["parser_config"]["raptor"]["prompt"],
  349. row["parser_config"]["raptor"]["max_token"],
  350. row["parser_config"]["raptor"]["threshold"]
  351. )
  352. original_length = len(chunks)
  353. chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  354. doc = {
  355. "doc_id": row["doc_id"],
  356. "kb_id": [str(row["kb_id"])],
  357. "docnm_kwd": row["name"],
  358. "title_tks": rag_tokenizer.tokenize(row["name"])
  359. }
  360. if row["pagerank"]:
  361. doc[PAGERANK_FLD] = int(row["pagerank"])
  362. res = []
  363. tk_count = 0
  364. for content, vctr in chunks[original_length:]:
  365. d = copy.deepcopy(doc)
  366. d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
  367. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  368. d["create_timestamp_flt"] = datetime.now().timestamp()
  369. d[vctr_nm] = vctr.tolist()
  370. d["content_with_weight"] = content
  371. d["content_ltks"] = rag_tokenizer.tokenize(content)
  372. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  373. res.append(d)
  374. tk_count += num_tokens_from_string(content)
  375. return res, tk_count
  376. def run_graphrag(row, chat_model, language, embedding_model, callback=None):
  377. chunks = []
  378. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  379. fields=["content_with_weight", "doc_id"]):
  380. chunks.append((d["doc_id"], d["content_with_weight"]))
  381. Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
  382. row["tenant_id"],
  383. str(row["kb_id"]),
  384. chat_model,
  385. chunks=chunks,
  386. language=language,
  387. entity_types=row["parser_config"]["graphrag"]["entity_types"],
  388. embed_bdl=embedding_model,
  389. callback=callback)
  390. def do_handle_task(task):
  391. task_id = task["id"]
  392. task_from_page = task["from_page"]
  393. task_to_page = task["to_page"]
  394. task_tenant_id = task["tenant_id"]
  395. task_embedding_id = task["embd_id"]
  396. task_language = task["language"]
  397. task_llm_id = task["llm_id"]
  398. task_dataset_id = task["kb_id"]
  399. task_doc_id = task["doc_id"]
  400. task_document_name = task["name"]
  401. task_parser_config = task["parser_config"]
  402. # prepare the progress callback function
  403. progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
  404. # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
  405. lower_case_doc_engine = settings.DOC_ENGINE.lower()
  406. if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
  407. error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
  408. progress_callback(-1, msg=error_message)
  409. raise Exception(error_message)
  410. try:
  411. task_canceled = TaskService.do_cancel(task_id)
  412. except DoesNotExist:
  413. logging.warning(f"task {task_id} is unknown")
  414. return
  415. if task_canceled:
  416. progress_callback(-1, msg="Task has been canceled.")
  417. return
  418. try:
  419. # bind embedding model
  420. embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
  421. except Exception as e:
  422. error_message = f'Fail to bind embedding model: {str(e)}'
  423. progress_callback(-1, msg=error_message)
  424. logging.exception(error_message)
  425. raise
  426. vts, _ = embedding_model.encode(["ok"])
  427. vector_size = len(vts[0])
  428. init_kb(task, vector_size)
  429. # Either using RAPTOR or Standard chunking methods
  430. if task.get("task_type", "") == "raptor":
  431. try:
  432. # bind LLM for raptor
  433. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  434. # run RAPTOR
  435. chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
  436. except TaskCanceledException:
  437. raise
  438. except Exception as e:
  439. error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
  440. progress_callback(-1, msg=error_message)
  441. logging.exception(error_message)
  442. raise
  443. # Either using graphrag or Standard chunking methods
  444. elif task.get("task_type", "") == "graphrag":
  445. start_ts = timer()
  446. try:
  447. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  448. run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
  449. progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
  450. except TaskCanceledException:
  451. raise
  452. except Exception as e:
  453. error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
  454. progress_callback(-1, msg=error_message)
  455. logging.exception(error_message)
  456. raise
  457. return
  458. elif task.get("task_type", "") == "graph_resolution":
  459. start_ts = timer()
  460. try:
  461. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  462. WithResolution(
  463. task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
  464. progress_callback
  465. )
  466. progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
  467. except TaskCanceledException:
  468. raise
  469. except Exception as e:
  470. error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
  471. progress_callback(-1, msg=error_message)
  472. logging.exception(error_message)
  473. raise
  474. return
  475. elif task.get("task_type", "") == "graph_community":
  476. start_ts = timer()
  477. try:
  478. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  479. WithCommunity(
  480. task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
  481. progress_callback
  482. )
  483. progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
  484. except TaskCanceledException:
  485. raise
  486. except Exception as e:
  487. error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
  488. progress_callback(-1, msg=error_message)
  489. logging.exception(error_message)
  490. raise
  491. return
  492. else:
  493. # Standard chunking methods
  494. start_ts = timer()
  495. chunks = build_chunks(task, progress_callback)
  496. logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
  497. if chunks is None:
  498. return
  499. if not chunks:
  500. progress_callback(1., msg=f"No chunk built from {task_document_name}")
  501. return
  502. # TODO: exception handler
  503. ## set_progress(task["did"], -1, "ERROR: ")
  504. progress_callback(msg="Generate {} chunks".format(len(chunks)))
  505. start_ts = timer()
  506. try:
  507. token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
  508. except Exception as e:
  509. error_message = "Generate embedding error:{}".format(str(e))
  510. progress_callback(-1, error_message)
  511. logging.exception(error_message)
  512. token_count = 0
  513. raise
  514. progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
  515. logging.info(progress_message)
  516. progress_callback(msg=progress_message)
  517. chunk_count = len(set([chunk["id"] for chunk in chunks]))
  518. start_ts = timer()
  519. doc_store_result = ""
  520. es_bulk_size = 4
  521. for b in range(0, len(chunks), es_bulk_size):
  522. doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
  523. task_dataset_id)
  524. if b % 128 == 0:
  525. progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
  526. if doc_store_result:
  527. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  528. progress_callback(-1, msg=error_message)
  529. raise Exception(error_message)
  530. chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
  531. chunk_ids_str = " ".join(chunk_ids)
  532. try:
  533. TaskService.update_chunk_ids(task["id"], chunk_ids_str)
  534. except DoesNotExist:
  535. logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
  536. doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
  537. task_dataset_id)
  538. return
  539. logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
  540. task_to_page, len(chunks),
  541. timer() - start_ts))
  542. DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
  543. time_cost = timer() - start_ts
  544. progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
  545. logging.info(
  546. "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
  547. task_to_page, len(chunks),
  548. token_count, time_cost))
  549. def handle_task():
  550. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  551. task = collect()
  552. if task:
  553. try:
  554. logging.info(f"handle_task begin for task {json.dumps(task)}")
  555. with mt_lock:
  556. CURRENT_TASK = copy.deepcopy(task)
  557. do_handle_task(task)
  558. with mt_lock:
  559. DONE_TASKS += 1
  560. CURRENT_TASK = None
  561. logging.info(f"handle_task done for task {json.dumps(task)}")
  562. except TaskCanceledException:
  563. with mt_lock:
  564. DONE_TASKS += 1
  565. CURRENT_TASK = None
  566. try:
  567. set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
  568. except Exception:
  569. pass
  570. logging.debug("handle_task got TaskCanceledException", exc_info=True)
  571. except Exception as e:
  572. with mt_lock:
  573. FAILED_TASKS += 1
  574. CURRENT_TASK = None
  575. try:
  576. set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
  577. except Exception:
  578. pass
  579. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  580. if PAYLOAD:
  581. PAYLOAD.ack()
  582. PAYLOAD = None
  583. def report_status():
  584. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  585. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  586. while True:
  587. try:
  588. now = datetime.now()
  589. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  590. if group_info is not None:
  591. PENDING_TASKS = int(group_info.get("pending", 0))
  592. LAG_TASKS = int(group_info.get("lag", 0))
  593. with mt_lock:
  594. heartbeat = json.dumps({
  595. "name": CONSUMER_NAME,
  596. "now": now.astimezone().isoformat(timespec="milliseconds"),
  597. "boot_at": BOOT_AT,
  598. "pending": PENDING_TASKS,
  599. "lag": LAG_TASKS,
  600. "done": DONE_TASKS,
  601. "failed": FAILED_TASKS,
  602. "current": CURRENT_TASK,
  603. })
  604. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  605. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  606. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  607. if expired > 0:
  608. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  609. except Exception:
  610. logging.exception("report_status got exception")
  611. time.sleep(30)
  612. def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
  613. msg = ""
  614. if dump_full:
  615. stats2 = snapshot2.statistics('lineno')
  616. msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
  617. for stat in stats2[:10]:
  618. msg += f"{stat}\n"
  619. stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
  620. msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
  621. for stat in stats1_vs_2[:10]:
  622. msg += f"{stat}\n"
  623. msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
  624. for stat in stats1_vs_2[:3]:
  625. msg += '\n'.join(stat.traceback.format())
  626. logging.info(msg)
  627. def main():
  628. logging.info(r"""
  629. ______ __ ______ __
  630. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  631. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  632. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  633. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  634. """)
  635. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  636. settings.init_settings()
  637. print_rag_settings()
  638. background_thread = threading.Thread(target=report_status)
  639. background_thread.daemon = True
  640. background_thread.start()
  641. TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
  642. TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
  643. if TRACE_MALLOC_DELTA > 0:
  644. if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
  645. TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
  646. tracemalloc.start()
  647. snapshot1 = tracemalloc.take_snapshot()
  648. while True:
  649. handle_task()
  650. num_tasks = DONE_TASKS + FAILED_TASKS
  651. if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
  652. snapshot2 = tracemalloc.take_snapshot()
  653. analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
  654. snapshot1 = snapshot2
  655. snapshot2 = None
# Script entry point: start the task-executor loop when run directly.
# The worker index is read from sys.argv at import time, before this point.
if __name__ == "__main__":
    main()