Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

task_executor.py 32KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import random
  19. import sys
  20. from api.utils.log_utils import initRootLogger, get_project_base_directory
  21. from graphrag.general.index import WithCommunity, WithResolution, Dealer
  22. from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
  23. from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
  24. from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
  25. from rag.prompts import keyword_extraction, question_proposal, content_tagging
# Consumer index: first command-line argument, or "0" when none is given.
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
# Name passed to initRootLogger below.
# NOTE(review): CONSUMER_NAME is reassigned later in this file with a
# "task_consumer_" prefix - confirm that both names are intended.
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
# Logging is initialised here, before the remaining imports are executed.
initRootLogger(CONSUMER_NAME)
  29. import asyncio
  30. import logging
  31. import os
  32. from datetime import datetime
  33. import json
  34. import xxhash
  35. import copy
  36. import re
  37. import time
  38. import threading
  39. from functools import partial
  40. from io import BytesIO
  41. from multiprocessing.context import TimeoutError
  42. from timeit import default_timer as timer
  43. import tracemalloc
  44. import signal
  45. import numpy as np
  46. from peewee import DoesNotExist
  47. from api.db import LLMType, ParserType, TaskStatus
  48. from api.db.services.document_service import DocumentService
  49. from api.db.services.llm_service import LLMBundle
  50. from api.db.services.task_service import TaskService
  51. from api.db.services.file2document_service import File2DocumentService
  52. from api import settings
  53. from api.versions import get_ragflow_version
  54. from api.db.db_models import close_connection
  55. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  56. email, tag
  57. from rag.nlp import search, rag_tokenizer
  58. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  59. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
  60. from rag.utils import num_tokens_from_string
  61. from rag.utils.redis_conn import REDIS_CONN, Payload
  62. from rag.utils.storage_factory import STORAGE_IMPL
# NOTE(review): BATCH_SIZE is not referenced by any code visible in this file.
BATCH_SIZE = 64

# Maps a lower-cased parser id to the rag.app module whose chunk() is used by
# build_chunks (looked up via task["parser_id"].lower()).
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: naive,
    ParserType.TAG.value: tag
}

# Replaces the "task_executor_" name assigned before logging was initialised;
# this value is what collect() and report_status() pass to REDIS_CONN.
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
# Queue payload of the task currently being handled; acked and reset to None
# by set_progress() / handle_task().
PAYLOAD: Payload | None = None
# Process start time, reported in every heartbeat.
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
# Queue statistics refreshed by report_status() from REDIS_CONN.queue_info().
PENDING_TASKS = 0
LAG_TASKS = 0

# Guards the counters and CURRENT_TASK below, which are written by the task
# loop and read by the report_status() thread.
mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None

# True while tracemalloc has been started by the SIGUSR1 handler.
tracemalloc_started = False
  92. # SIGUSR1 handler: start tracemalloc and take snapshot
  93. def start_tracemalloc_and_snapshot(signum, frame):
  94. global tracemalloc_started
  95. if not tracemalloc_started:
  96. logging.info("got SIGUSR1, start tracemalloc")
  97. tracemalloc.start()
  98. tracemalloc_started = True
  99. else:
  100. logging.info("got SIGUSR1, tracemalloc is already running")
  101. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  102. snapshot_file = f"snapshot_{timestamp}.trace"
  103. snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
  104. snapshot = tracemalloc.take_snapshot()
  105. snapshot.dump(snapshot_file)
  106. logging.info(f"taken snapshot {snapshot_file}")
  107. # SIGUSR2 handler: stop tracemalloc
  108. def stop_tracemalloc(signum, frame):
  109. global tracemalloc_started
  110. if tracemalloc_started:
  111. logging.info("go SIGUSR2, stop tracemalloc")
  112. tracemalloc.stop()
  113. tracemalloc_started = False
  114. else:
  115. logging.info("got SIGUSR2, tracemalloc not running")
  116. class TaskCanceledException(Exception):
  117. def __init__(self, msg):
  118. self.msg = msg
  119. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  120. global PAYLOAD
  121. if prog is not None and prog < 0:
  122. msg = "[ERROR]" + msg
  123. try:
  124. cancel = TaskService.do_cancel(task_id)
  125. except DoesNotExist:
  126. logging.warning(f"set_progress task {task_id} is unknown")
  127. if PAYLOAD:
  128. PAYLOAD.ack()
  129. PAYLOAD = None
  130. return
  131. if cancel:
  132. msg += " [Canceled]"
  133. prog = -1
  134. if to_page > 0:
  135. if msg:
  136. if from_page < to_page:
  137. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  138. if msg:
  139. msg = datetime.now().strftime("%H:%M:%S") + " " + msg
  140. d = {"progress_msg": msg}
  141. if prog is not None:
  142. d["progress"] = prog
  143. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  144. try:
  145. TaskService.update_progress(task_id, d)
  146. except DoesNotExist:
  147. logging.warning(f"set_progress task {task_id} is unknown")
  148. if PAYLOAD:
  149. PAYLOAD.ack()
  150. PAYLOAD = None
  151. return
  152. close_connection()
  153. if cancel and PAYLOAD:
  154. PAYLOAD.ack()
  155. PAYLOAD = None
  156. raise TaskCanceledException(msg)
  157. def collect():
  158. global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
  159. try:
  160. PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  161. if not PAYLOAD:
  162. PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  163. if not PAYLOAD:
  164. time.sleep(1)
  165. return None
  166. except Exception:
  167. logging.exception("Get task event from queue exception")
  168. return None
  169. msg = PAYLOAD.get_message()
  170. if not msg:
  171. return None
  172. task = None
  173. canceled = False
  174. try:
  175. task = TaskService.get_task(msg["id"])
  176. if task:
  177. _, doc = DocumentService.get_by_id(task["doc_id"])
  178. canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
  179. except DoesNotExist:
  180. pass
  181. except Exception:
  182. logging.exception("collect get_task exception")
  183. if not task or canceled:
  184. state = "is unknown" if not task else "has been cancelled"
  185. with mt_lock:
  186. DONE_TASKS += 1
  187. logging.info(f"collect task {msg['id']} {state}")
  188. return None
  189. task["task_type"] = msg.get("task_type", "")
  190. return task
  191. def get_storage_binary(bucket, name):
  192. return STORAGE_IMPL.get(bucket, name)
def build_chunks(task, progress_callback):
    """Chunk the task's document and return the chunk dicts.

    Steps, in order: fetch the file from storage, run the parser module
    selected by ``task["parser_id"]``, give each chunk an id / timestamps and
    upload its image (if any), then optionally enrich the chunks with
    LLM-generated keywords, questions and tags depending on the parser config.

    Args:
        task: task dict; the keys read here include "size", "parser_id",
            "doc_id", "kb_id", "tenant_id", "name", "location", "language",
            "llm_id", "from_page", "to_page", "pagerank", "parser_config" and
            "kb_parser_config".
        progress_callback: called as ``progress_callback(prog, msg)`` or
            ``progress_callback(msg=...)`` to report progress.

    Returns:
        A list of chunk dicts; an empty list when the file is larger than
        DOC_MAXIMUM_SIZE.

    Raises:
        Re-raises any exception from fetching the file, chunking, or storing a
        chunk image, after logging it (and reporting progress -1 for the first
        two).
    """
    if task["size"] > DOC_MAXIMUM_SIZE:
        # Reported through set_progress directly, i.e. without the page range
        # that progress_callback would carry.
        set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                                              (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []

    chunker = FACTORY[task["parser_id"].lower()]
    try:
        st = timer()
        bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
        binary = get_storage_binary(bucket, name)
        logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
    except TimeoutError:
        progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
        logging.exception(
            "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
        raise
    except Exception as e:
        # Distinguish a missing file from any other storage failure in the user-facing message.
        if re.search("(No such file|not found)", str(e)):
            progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
        else:
            progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise

    try:
        cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
                            to_page=task["to_page"], lang=task["language"], callback=progress_callback,
                            kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
        # st still holds the start of the storage fetch, so this duration covers fetch + chunking.
        logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
    except TaskCanceledException:
        # Cancellation is not an error: propagate without reporting progress -1.
        raise
    except Exception as e:
        progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise

    docs = []
    # Fields shared by every chunk of this document.
    doc = {
        "doc_id": task["doc_id"],
        "kb_id": str(task["kb_id"])
    }
    if task["pagerank"]:
        doc[PAGERANK_FLD] = int(task["pagerank"])
    # Accumulated time spent uploading chunk images.
    el = 0
    for ck in cks:
        d = copy.deepcopy(doc)
        d.update(ck)
        # Chunk id: xxh64 hex digest of the chunk text followed by the doc id.
        d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.now().timestamp()
        if not d.get("image"):
            # No image: drop the key, leave img_id empty and skip the upload.
            _ = d.pop("image", None)
            d["img_id"] = ""
            docs.append(d)
            continue

        try:
            output_buffer = BytesIO()
            if isinstance(d["image"], bytes):
                output_buffer = BytesIO(d["image"])
            else:
                # presumably a PIL image; it only needs a save(buffer, format=...) method - confirm
                d["image"].save(output_buffer, format='JPEG')

            st = timer()
            STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
            el += timer() - st
        except Exception:
            logging.exception(
                "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
            raise

        d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
        del d["image"]
        docs.append(d)
    logging.info("MINIO PUT({}):{}".format(task["name"], el))

    if task["parser_config"].get("auto_keywords", 0):
        st = timer()
        progress_callback(msg="Start to generate keywords for every chunk ...")
        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])

        async def doc_keyword_extraction(chat_mdl, d, topn):
            # Reuse a cached LLM answer when there is one; otherwise run the
            # blocking keyword_extraction call in a worker thread and cache it.
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
            if not cached:
                cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
            if cached:
                d["important_kwd"] = cached.split(",")
                d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
            return
        tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
        # NOTE(review): asyncio.get_event_loop() without a running loop is deprecated
        # in recent Python versions - confirm against the supported interpreter.
        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
        progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    if task["parser_config"].get("auto_questions", 0):
        st = timer()
        progress_callback(msg="Start to generate questions for every chunk ...")
        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])

        async def doc_question_proposal(chat_mdl, d, topn):
            # Same cache-then-LLM pattern as the keyword extraction above.
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
            if not cached:
                cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
            if cached:
                d["question_kwd"] = cached.split("\n")
                d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
        tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
        progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    if task["kb_parser_config"].get("tag_kb_ids", []):
        progress_callback(msg="Start to tag for every chunk ...")
        kb_ids = task["kb_parser_config"]["tag_kb_ids"]
        tenant_id = task["tenant_id"]
        topn_tags = task["kb_parser_config"].get("topn_tags", 3)
        # Passed as S to all_tags_in_portion and tag_content; its meaning is defined by the retriever.
        S = 1000
        st = timer()
        examples = []
        all_tags = get_tags_from_cache(kb_ids)
        if not all_tags:
            all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
            set_tags_to_cache(kb_ids, all_tags)
        else:
            # The cache returns the tags JSON-encoded.
            all_tags = json.loads(all_tags)

        chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])

        docs_to_tag = []
        for d in docs:
            # Chunks for which tag_content returns a truthy value become few-shot
            # examples (their TAG_FLD is read right away); the rest go to the LLM.
            if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
                examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
            else:
                docs_to_tag.append(d)

        async def doc_content_tagging(chat_mdl, d, topn_tags):
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
            if not cached:
                # NOTE(review): random.choices samples with replacement, so the two
                # picked examples can be the same one.
                picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
                cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
                if cached:
                    cached = json.dumps(cached)
            if cached:
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                d[TAG_FLD] = json.loads(cached)
        tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
        # NOTE(review): the count reported here is len(docs), not len(docs_to_tag).
        progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    return docs
  329. def init_kb(row, vector_size: int):
  330. idxnm = search.index_name(row["tenant_id"])
  331. return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
  332. def embedding(docs, mdl, parser_config=None, callback=None):
  333. if parser_config is None:
  334. parser_config = {}
  335. batch_size = 16
  336. tts, cnts = [], []
  337. for d in docs:
  338. tts.append(d.get("docnm_kwd", "Title"))
  339. c = "\n".join(d.get("question_kwd", []))
  340. if not c:
  341. c = d["content_with_weight"]
  342. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  343. if not c:
  344. c = "None"
  345. cnts.append(c)
  346. tk_count = 0
  347. if len(tts) == len(cnts):
  348. vts, c = mdl.encode(tts[0: 1])
  349. tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
  350. tk_count += c
  351. cnts_ = np.array([])
  352. for i in range(0, len(cnts), batch_size):
  353. vts, c = mdl.encode(cnts[i: i + batch_size])
  354. if len(cnts_) == 0:
  355. cnts_ = vts
  356. else:
  357. cnts_ = np.concatenate((cnts_, vts), axis=0)
  358. tk_count += c
  359. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  360. cnts = cnts_
  361. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  362. vects = (title_w * tts + (1 - title_w) *
  363. cnts) if len(tts) == len(cnts) else cnts
  364. assert len(vects) == len(docs)
  365. vector_size = 0
  366. for i, d in enumerate(docs):
  367. v = vects[i].tolist()
  368. vector_size = len(v)
  369. d["q_%d_vec" % len(v)] = v
  370. return tk_count, vector_size
  371. def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
  372. chunks = []
  373. vctr_nm = "q_%d_vec"%vector_size
  374. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  375. fields=["content_with_weight", vctr_nm]):
  376. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  377. raptor = Raptor(
  378. row["parser_config"]["raptor"].get("max_cluster", 64),
  379. chat_mdl,
  380. embd_mdl,
  381. row["parser_config"]["raptor"]["prompt"],
  382. row["parser_config"]["raptor"]["max_token"],
  383. row["parser_config"]["raptor"]["threshold"]
  384. )
  385. original_length = len(chunks)
  386. chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  387. doc = {
  388. "doc_id": row["doc_id"],
  389. "kb_id": [str(row["kb_id"])],
  390. "docnm_kwd": row["name"],
  391. "title_tks": rag_tokenizer.tokenize(row["name"])
  392. }
  393. if row["pagerank"]:
  394. doc[PAGERANK_FLD] = int(row["pagerank"])
  395. res = []
  396. tk_count = 0
  397. for content, vctr in chunks[original_length:]:
  398. d = copy.deepcopy(doc)
  399. d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
  400. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  401. d["create_timestamp_flt"] = datetime.now().timestamp()
  402. d[vctr_nm] = vctr.tolist()
  403. d["content_with_weight"] = content
  404. d["content_ltks"] = rag_tokenizer.tokenize(content)
  405. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  406. res.append(d)
  407. tk_count += num_tokens_from_string(content)
  408. return res, tk_count
  409. def run_graphrag(row, chat_model, language, embedding_model, callback=None):
  410. chunks = []
  411. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  412. fields=["content_with_weight", "doc_id"]):
  413. chunks.append((d["doc_id"], d["content_with_weight"]))
  414. Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
  415. row["tenant_id"],
  416. str(row["kb_id"]),
  417. chat_model,
  418. chunks=chunks,
  419. language=language,
  420. entity_types=row["parser_config"]["graphrag"]["entity_types"],
  421. embed_bdl=embedding_model,
  422. callback=callback)
def do_handle_task(task):
    """Process one task according to its "task_type".

    * "raptor": build RAPTOR chunks from the stored chunks, then index them.
    * "graphrag", "graph_resolution", "graph_community": run the corresponding
      knowledge-graph step and return; nothing is indexed here.
    * anything else: chunk the document, embed the chunks, then index them.

    Indexing inserts the chunks into the document store in small batches,
    records the inserted chunk ids on the task after every batch, and finally
    updates the document's chunk/token counters.

    Args:
        task: task dict as returned by collect().

    Raises:
        Exception: on unsupported engine/parser combinations, model binding or
            processing failures, and document-store insert errors.
        TaskCanceledException: propagated from the progress callback.
    """
    task_id = task["id"]
    task_from_page = task["from_page"]
    task_to_page = task["to_page"]
    task_tenant_id = task["tenant_id"]
    task_embedding_id = task["embd_id"]
    task_language = task["language"]
    task_llm_id = task["llm_id"]
    task_dataset_id = task["kb_id"]
    task_doc_id = task["doc_id"]
    task_document_name = task["name"]
    task_parser_config = task["parser_config"]

    # prepare the progress callback function
    progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

    # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
    lower_case_doc_engine = settings.DOC_ENGINE.lower()
    if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
        error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
        progress_callback(-1, msg=error_message)
        raise Exception(error_message)

    try:
        task_canceled = TaskService.do_cancel(task_id)
    except DoesNotExist:
        logging.warning(f"task {task_id} is unknown")
        return
    if task_canceled:
        progress_callback(-1, msg="Task has been canceled.")
        return

    try:
        # bind embedding model
        embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
    except Exception as e:
        error_message = f'Fail to bind embedding model: {str(e)}'
        progress_callback(-1, msg=error_message)
        logging.exception(error_message)
        raise

    # Encode a probe string to learn the embedding dimension, then make sure
    # the index for that dimension exists.
    vts, _ = embedding_model.encode(["ok"])
    vector_size = len(vts[0])
    init_kb(task, vector_size)

    # Either using RAPTOR or Standard chunking methods
    if task.get("task_type", "") == "raptor":
        try:
            # bind LLM for raptor
            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
            # run RAPTOR
            chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
        except TaskCanceledException:
            raise
        except Exception as e:
            # NOTE(review): this message is also used when run_raptor itself fails,
            # not only when binding the LLM fails.
            error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
            progress_callback(-1, msg=error_message)
            logging.exception(error_message)
            raise
    # Either using graphrag or Standard chunking methods
    elif task.get("task_type", "") == "graphrag":
        start_ts = timer()
        try:
            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
            run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
            progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
            progress_callback(-1, msg=error_message)
            logging.exception(error_message)
            raise
        return
    elif task.get("task_type", "") == "graph_resolution":
        start_ts = timer()
        try:
            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
            WithResolution(
                task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
                progress_callback
            )
            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
            progress_callback(-1, msg=error_message)
            logging.exception(error_message)
            raise
        return
    elif task.get("task_type", "") == "graph_community":
        start_ts = timer()
        try:
            chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
            WithCommunity(
                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
                progress_callback
            )
            progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
            progress_callback(-1, msg=error_message)
            logging.exception(error_message)
            raise
        return
    else:
        # Standard chunking methods
        start_ts = timer()
        chunks = build_chunks(task, progress_callback)
        logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
        # NOTE(review): build_chunks as visible in this file never returns None.
        if chunks is None:
            return
        if not chunks:
            progress_callback(1., msg=f"No chunk built from {task_document_name}")
            return
        # TODO: exception handler
        ## set_progress(task["did"], -1, "ERROR: ")
        progress_callback(msg="Generate {} chunks".format(len(chunks)))
        start_ts = timer()
        try:
            token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
        except Exception as e:
            error_message = "Generate embedding error:{}".format(str(e))
            progress_callback(-1, error_message)
            logging.exception(error_message)
            # NOTE(review): this assignment has no effect because the exception is re-raised.
            token_count = 0
            raise

        progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
        logging.info(progress_message)
        progress_callback(msg=progress_message)

    # Reached by the raptor and standard paths only: index the chunks.
    # Duplicate chunk ids are counted once.
    chunk_count = len(set([chunk["id"] for chunk in chunks]))
    start_ts = timer()
    doc_store_result = ""
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
                                                        task_dataset_id)
        # Report progress only every 128 chunks.
        if b % 128 == 0:
            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
        # A truthy insert result is treated as an error description.
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            progress_callback(-1, msg=error_message)
            raise Exception(error_message)
        # Record every chunk id inserted so far on the task after each batch.
        chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
        chunk_ids_str = " ".join(chunk_ids)
        try:
            TaskService.update_chunk_ids(task["id"], chunk_ids_str)
        except DoesNotExist:
            # The task vanished meanwhile: remove what was inserted and stop.
            logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
            doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
                                                            task_dataset_id)
            return
    logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
                                                                                     task_to_page, len(chunks),
                                                                                     timer() - start_ts))

    DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

    # start_ts was reset before the insert loop, so this covers indexing only.
    time_cost = timer() - start_ts
    progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
    logging.info(
        "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
                                                                                   task_to_page, len(chunks),
                                                                                   token_count, time_cost))
  582. def handle_task():
  583. global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  584. task = collect()
  585. if task:
  586. try:
  587. logging.info(f"handle_task begin for task {json.dumps(task)}")
  588. with mt_lock:
  589. CURRENT_TASK = copy.deepcopy(task)
  590. do_handle_task(task)
  591. with mt_lock:
  592. DONE_TASKS += 1
  593. CURRENT_TASK = None
  594. logging.info(f"handle_task done for task {json.dumps(task)}")
  595. except TaskCanceledException:
  596. with mt_lock:
  597. DONE_TASKS += 1
  598. CURRENT_TASK = None
  599. try:
  600. set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
  601. except Exception:
  602. pass
  603. logging.debug("handle_task got TaskCanceledException", exc_info=True)
  604. except Exception as e:
  605. with mt_lock:
  606. FAILED_TASKS += 1
  607. CURRENT_TASK = None
  608. try:
  609. set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
  610. except Exception:
  611. pass
  612. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  613. if PAYLOAD:
  614. PAYLOAD.ack()
  615. PAYLOAD = None
  616. def report_status():
  617. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
  618. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  619. while True:
  620. try:
  621. now = datetime.now()
  622. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  623. if group_info is not None:
  624. PENDING_TASKS = int(group_info.get("pending", 0))
  625. LAG_TASKS = int(group_info.get("lag", 0))
  626. with mt_lock:
  627. heartbeat = json.dumps({
  628. "name": CONSUMER_NAME,
  629. "now": now.astimezone().isoformat(timespec="milliseconds"),
  630. "boot_at": BOOT_AT,
  631. "pending": PENDING_TASKS,
  632. "lag": LAG_TASKS,
  633. "done": DONE_TASKS,
  634. "failed": FAILED_TASKS,
  635. "current": CURRENT_TASK,
  636. })
  637. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  638. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  639. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  640. if expired > 0:
  641. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  642. except Exception:
  643. logging.exception("report_status got exception")
  644. time.sleep(30)
  645. def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
  646. msg = ""
  647. if dump_full:
  648. stats2 = snapshot2.statistics('lineno')
  649. msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
  650. for stat in stats2[:10]:
  651. msg += f"{stat}\n"
  652. stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
  653. msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
  654. for stat in stats1_vs_2[:10]:
  655. msg += f"{stat}\n"
  656. msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
  657. for stat in stats1_vs_2[:3]:
  658. msg += '\n'.join(stat.traceback.format())
  659. logging.info(msg)
def main():
    """Entry point of the task executor process.

    Logs the banner and version, initialises settings, installs the
    tracemalloc signal handlers (SIGUSR1 / SIGUSR2) and the SIGINT handler,
    starts the heartbeat thread and then handles tasks until stopped.
    """
    logging.info(r"""
______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
/ / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
/ / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
/_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
""")
    logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
    settings.init_settings()
    print_rag_settings()
    signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
    signal.signal(signal.SIGUSR2, stop_tracemalloc)
    # Setting TRACE_MALLOC_ENABLED to a non-zero integer starts tracing (and
    # takes a first snapshot) right at startup.
    TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
    if TRACE_MALLOC_ENABLED:
        start_tracemalloc_and_snapshot(None, None)

    # Create an event to signal the background thread to exit
    # NOTE(review): report_status does not receive or check this event, so in
    # the visible code it only ends the task loop below.
    stop_event = threading.Event()

    # Heartbeat reporter; daemon so it does not keep the process alive.
    background_thread = threading.Thread(target=report_status)
    background_thread.daemon = True
    background_thread.start()

    # Handle SIGINT (Ctrl+C)
    def signal_handler(sig, frame):
        logging.info("Received Ctrl+C, shutting down gracefully...")
        stop_event.set()
        # Give the background thread time to clean up
        if background_thread.is_alive():
            background_thread.join(timeout=5)
        logging.info("Exiting...")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    try:
        while not stop_event.is_set():
            handle_task()
    except KeyboardInterrupt:
        logging.info("Interrupted by keyboard, shutting down...")
        stop_event.set()
        if background_thread.is_alive():
            background_thread.join(timeout=5)
# Run the executor only when this file is executed as a script.
if __name__ == "__main__":
    main()