You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

task_executor.py 29KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import random
  19. import sys
  20. from api.utils.log_utils import initRootLogger, get_project_base_directory
  21. from graphrag.general.index import WithCommunity, WithResolution, Dealer
  22. from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
  23. from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
  24. from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
  25. from rag.prompts import keyword_extraction, question_proposal, content_tagging
# Consumer number taken from the first command-line argument (defaults to "0");
# presumably used to tell several executor processes apart — confirm with the launcher.
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
# Name given to the root logger. NOTE: CONSUMER_NAME is reassigned further down to
# "task_consumer_<N>"; that later value is the one the Redis consumer code reads.
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
# Configure the root logger before the remaining imports execute.
# NOTE(review): presumably so that anything logged during those imports goes through it — confirm.
initRootLogger(CONSUMER_NAME)
  29. import logging
  30. import os
  31. from datetime import datetime
  32. import json
  33. import xxhash
  34. import copy
  35. import re
  36. from functools import partial
  37. from io import BytesIO
  38. from multiprocessing.context import TimeoutError
  39. from timeit import default_timer as timer
  40. import tracemalloc
  41. import resource
  42. import signal
  43. import trio
  44. import numpy as np
  45. from peewee import DoesNotExist
  46. from api.db import LLMType, ParserType, TaskStatus
  47. from api.db.services.document_service import DocumentService
  48. from api.db.services.llm_service import LLMBundle
  49. from api.db.services.task_service import TaskService
  50. from api.db.services.file2document_service import File2DocumentService
  51. from api import settings
  52. from api.versions import get_ragflow_version
  53. from api.db.db_models import close_connection
  54. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  55. email, tag
  56. from rag.nlp import search, rag_tokenizer
  57. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  58. from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
  59. from rag.utils import num_tokens_from_string
  60. from rag.utils.redis_conn import REDIS_CONN
  61. from rag.utils.storage_factory import STORAGE_IMPL
  62. from graphrag.utils import chat_limiter
# NOTE(review): BATCH_SIZE is not referenced anywhere else in this file — confirm it is still needed.
BATCH_SIZE = 64
# Parser id -> rag.app module. build_chunks() looks up FACTORY[task["parser_id"].lower()]
# and calls that module's chunk(). "general" and the KG parser both map to the naive chunker.
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: naive,
    ParserType.TAG.value: tag
}
# Iterator returned by REDIS_CONN.get_unacked_iterator(); created lazily by collect(),
# which drains it before reading new messages from the queue.
UNACKED_ITERATOR = None
# Consumer name used for Redis (queue consumer, heartbeat sorted set). This replaces the
# "task_executor_<N>" value that was used to name the root logger.
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
# Process start time, included in every heartbeat.
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
# Queue statistics refreshed by report_status() from the consumer group info.
PENDING_TASKS = 0
LAG_TASKS = 0
# Task counters reported in heartbeats.
DONE_TASKS = 0
FAILED_TASKS = 0
# task id -> deep copy of the task dict currently being processed by handle_task().
CURRENT_TASKS = {}
# Concurrency limits, overridable through environment variables.
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
# task_limiter is used by main() when spawning handle_task; chunk_limiter wraps the
# parser's chunk() call inside build_chunks().
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
  94. # SIGUSR1 handler: start tracemalloc and take snapshot
  95. def start_tracemalloc_and_snapshot(signum, frame):
  96. if not tracemalloc.is_tracing():
  97. logging.info("start tracemalloc")
  98. tracemalloc.start()
  99. else:
  100. logging.info("tracemalloc is already running")
  101. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  102. snapshot_file = f"snapshot_{timestamp}.trace"
  103. snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
  104. snapshot = tracemalloc.take_snapshot()
  105. snapshot.dump(snapshot_file)
  106. current, peak = tracemalloc.get_traced_memory()
  107. max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  108. logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
  109. # SIGUSR2 handler: stop tracemalloc
  110. def stop_tracemalloc(signum, frame):
  111. if tracemalloc.is_tracing():
  112. logging.info("stop tracemalloc")
  113. tracemalloc.stop()
  114. else:
  115. logging.info("tracemalloc not running")
  116. class TaskCanceledException(Exception):
  117. def __init__(self, msg):
  118. self.msg = msg
  119. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  120. if prog is not None and prog < 0:
  121. msg = "[ERROR]" + msg
  122. cancel = TaskService.do_cancel(task_id)
  123. if cancel:
  124. msg += " [Canceled]"
  125. prog = -1
  126. if to_page > 0:
  127. if msg:
  128. if from_page < to_page:
  129. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  130. if msg:
  131. msg = datetime.now().strftime("%H:%M:%S") + " " + msg
  132. d = {"progress_msg": msg}
  133. if prog is not None:
  134. d["progress"] = prog
  135. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  136. TaskService.update_progress(task_id, d)
  137. close_connection()
  138. if cancel:
  139. raise TaskCanceledException(msg)
  140. async def collect():
  141. global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
  142. global UNACKED_ITERATOR
  143. try:
  144. if not UNACKED_ITERATOR:
  145. UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  146. try:
  147. redis_msg = next(UNACKED_ITERATOR)
  148. except StopIteration:
  149. redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
  150. if not redis_msg:
  151. await trio.sleep(1)
  152. return None, None
  153. except Exception:
  154. logging.exception("collect got exception")
  155. return None, None
  156. msg = redis_msg.get_message()
  157. if not msg:
  158. logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
  159. redis_msg.ack()
  160. return None, None
  161. canceled = False
  162. task = TaskService.get_task(msg["id"])
  163. if task:
  164. _, doc = DocumentService.get_by_id(task["doc_id"])
  165. canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
  166. if not task or canceled:
  167. state = "is unknown" if not task else "has been cancelled"
  168. FAILED_TASKS += 1
  169. logging.warning(f"collect task {msg['id']} {state}")
  170. redis_msg.ack()
  171. return None
  172. task["task_type"] = msg.get("task_type", "")
  173. return redis_msg, task
  174. async def get_storage_binary(bucket, name):
  175. return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
  176. async def build_chunks(task, progress_callback):
  177. if task["size"] > DOC_MAXIMUM_SIZE:
  178. set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  179. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  180. return []
  181. chunker = FACTORY[task["parser_id"].lower()]
  182. try:
  183. st = timer()
  184. bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
  185. binary = await get_storage_binary(bucket, name)
  186. logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
  187. except TimeoutError:
  188. progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  189. logging.exception(
  190. "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
  191. raise
  192. except Exception as e:
  193. if re.search("(No such file|not found)", str(e)):
  194. progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
  195. else:
  196. progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  197. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  198. raise
  199. try:
  200. async with chunk_limiter:
  201. cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
  202. to_page=task["to_page"], lang=task["language"], callback=progress_callback,
  203. kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
  204. logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
  205. except TaskCanceledException:
  206. raise
  207. except Exception as e:
  208. progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
  209. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  210. raise
  211. docs = []
  212. doc = {
  213. "doc_id": task["doc_id"],
  214. "kb_id": str(task["kb_id"])
  215. }
  216. if task["pagerank"]:
  217. doc[PAGERANK_FLD] = int(task["pagerank"])
  218. el = 0
  219. for ck in cks:
  220. d = copy.deepcopy(doc)
  221. d.update(ck)
  222. d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
  223. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  224. d["create_timestamp_flt"] = datetime.now().timestamp()
  225. if not d.get("image"):
  226. _ = d.pop("image", None)
  227. d["img_id"] = ""
  228. docs.append(d)
  229. continue
  230. try:
  231. output_buffer = BytesIO()
  232. if isinstance(d["image"], bytes):
  233. output_buffer = BytesIO(d["image"])
  234. else:
  235. d["image"].save(output_buffer, format='JPEG')
  236. st = timer()
  237. await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
  238. el += timer() - st
  239. except Exception:
  240. logging.exception(
  241. "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
  242. raise
  243. d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
  244. del d["image"]
  245. docs.append(d)
  246. logging.info("MINIO PUT({}):{}".format(task["name"], el))
  247. if task["parser_config"].get("auto_keywords", 0):
  248. st = timer()
  249. progress_callback(msg="Start to generate keywords for every chunk ...")
  250. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  251. async def doc_keyword_extraction(chat_mdl, d, topn):
  252. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
  253. if not cached:
  254. async with chat_limiter:
  255. cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
  256. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
  257. if cached:
  258. d["important_kwd"] = cached.split(",")
  259. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  260. return
  261. async with trio.open_nursery() as nursery:
  262. for d in docs:
  263. nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
  264. progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  265. if task["parser_config"].get("auto_questions", 0):
  266. st = timer()
  267. progress_callback(msg="Start to generate questions for every chunk ...")
  268. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  269. async def doc_question_proposal(chat_mdl, d, topn):
  270. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
  271. if not cached:
  272. async with chat_limiter:
  273. cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
  274. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
  275. if cached:
  276. d["question_kwd"] = cached.split("\n")
  277. d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
  278. async with trio.open_nursery() as nursery:
  279. for d in docs:
  280. nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
  281. progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  282. if task["kb_parser_config"].get("tag_kb_ids", []):
  283. progress_callback(msg="Start to tag for every chunk ...")
  284. kb_ids = task["kb_parser_config"]["tag_kb_ids"]
  285. tenant_id = task["tenant_id"]
  286. topn_tags = task["kb_parser_config"].get("topn_tags", 3)
  287. S = 1000
  288. st = timer()
  289. examples = []
  290. all_tags = get_tags_from_cache(kb_ids)
  291. if not all_tags:
  292. all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
  293. set_tags_to_cache(kb_ids, all_tags)
  294. else:
  295. all_tags = json.loads(all_tags)
  296. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  297. docs_to_tag = []
  298. for d in docs:
  299. if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
  300. examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
  301. else:
  302. docs_to_tag.append(d)
  303. async def doc_content_tagging(chat_mdl, d, topn_tags):
  304. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
  305. if not cached:
  306. picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
  307. async with chat_limiter:
  308. cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
  309. if cached:
  310. cached = json.dumps(cached)
  311. if cached:
  312. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
  313. d[TAG_FLD] = json.loads(cached)
  314. async with trio.open_nursery() as nursery:
  315. for d in docs_to_tag:
  316. nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
  317. progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  318. return docs
  319. def init_kb(row, vector_size: int):
  320. idxnm = search.index_name(row["tenant_id"])
  321. return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
  322. async def embedding(docs, mdl, parser_config=None, callback=None):
  323. if parser_config is None:
  324. parser_config = {}
  325. batch_size = 16
  326. tts, cnts = [], []
  327. for d in docs:
  328. tts.append(d.get("docnm_kwd", "Title"))
  329. c = "\n".join(d.get("question_kwd", []))
  330. if not c:
  331. c = d["content_with_weight"]
  332. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  333. if not c:
  334. c = "None"
  335. cnts.append(c)
  336. tk_count = 0
  337. if len(tts) == len(cnts):
  338. vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
  339. tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
  340. tk_count += c
  341. cnts_ = np.array([])
  342. for i in range(0, len(cnts), batch_size):
  343. vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
  344. if len(cnts_) == 0:
  345. cnts_ = vts
  346. else:
  347. cnts_ = np.concatenate((cnts_, vts), axis=0)
  348. tk_count += c
  349. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  350. cnts = cnts_
  351. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  352. vects = (title_w * tts + (1 - title_w) *
  353. cnts) if len(tts) == len(cnts) else cnts
  354. assert len(vects) == len(docs)
  355. vector_size = 0
  356. for i, d in enumerate(docs):
  357. v = vects[i].tolist()
  358. vector_size = len(v)
  359. d["q_%d_vec" % len(v)] = v
  360. return tk_count, vector_size
  361. async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
  362. chunks = []
  363. vctr_nm = "q_%d_vec"%vector_size
  364. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  365. fields=["content_with_weight", vctr_nm]):
  366. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  367. raptor = Raptor(
  368. row["parser_config"]["raptor"].get("max_cluster", 64),
  369. chat_mdl,
  370. embd_mdl,
  371. row["parser_config"]["raptor"]["prompt"],
  372. row["parser_config"]["raptor"]["max_token"],
  373. row["parser_config"]["raptor"]["threshold"]
  374. )
  375. original_length = len(chunks)
  376. chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  377. doc = {
  378. "doc_id": row["doc_id"],
  379. "kb_id": [str(row["kb_id"])],
  380. "docnm_kwd": row["name"],
  381. "title_tks": rag_tokenizer.tokenize(row["name"])
  382. }
  383. if row["pagerank"]:
  384. doc[PAGERANK_FLD] = int(row["pagerank"])
  385. res = []
  386. tk_count = 0
  387. for content, vctr in chunks[original_length:]:
  388. d = copy.deepcopy(doc)
  389. d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
  390. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  391. d["create_timestamp_flt"] = datetime.now().timestamp()
  392. d[vctr_nm] = vctr.tolist()
  393. d["content_with_weight"] = content
  394. d["content_ltks"] = rag_tokenizer.tokenize(content)
  395. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  396. res.append(d)
  397. tk_count += num_tokens_from_string(content)
  398. return res, tk_count
  399. async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
  400. chunks = []
  401. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  402. fields=["content_with_weight", "doc_id"]):
  403. chunks.append((d["doc_id"], d["content_with_weight"]))
  404. dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
  405. row["tenant_id"],
  406. str(row["kb_id"]),
  407. chat_model,
  408. chunks=chunks,
  409. language=language,
  410. entity_types=row["parser_config"]["graphrag"]["entity_types"],
  411. embed_bdl=embedding_model,
  412. callback=callback)
  413. await dealer()
  414. async def do_handle_task(task):
  415. task_id = task["id"]
  416. task_from_page = task["from_page"]
  417. task_to_page = task["to_page"]
  418. task_tenant_id = task["tenant_id"]
  419. task_embedding_id = task["embd_id"]
  420. task_language = task["language"]
  421. task_llm_id = task["llm_id"]
  422. task_dataset_id = task["kb_id"]
  423. task_doc_id = task["doc_id"]
  424. task_document_name = task["name"]
  425. task_parser_config = task["parser_config"]
  426. task_start_ts = timer()
  427. # prepare the progress callback function
  428. progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
  429. # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
  430. lower_case_doc_engine = settings.DOC_ENGINE.lower()
  431. if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
  432. error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
  433. progress_callback(-1, msg=error_message)
  434. raise Exception(error_message)
  435. task_canceled = TaskService.do_cancel(task_id)
  436. if task_canceled:
  437. progress_callback(-1, msg="Task has been canceled.")
  438. return
  439. try:
  440. # bind embedding model
  441. embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
  442. vts, _ = embedding_model.encode(["ok"])
  443. vector_size = len(vts[0])
  444. except Exception as e:
  445. error_message = f'Fail to bind embedding model: {str(e)}'
  446. progress_callback(-1, msg=error_message)
  447. logging.exception(error_message)
  448. raise
  449. init_kb(task, vector_size)
  450. # Either using RAPTOR or Standard chunking methods
  451. if task.get("task_type", "") == "raptor":
  452. # bind LLM for raptor
  453. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  454. # run RAPTOR
  455. chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
  456. # Either using graphrag or Standard chunking methods
  457. elif task.get("task_type", "") == "graphrag":
  458. start_ts = timer()
  459. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  460. await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
  461. progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
  462. return
  463. elif task.get("task_type", "") == "graph_resolution":
  464. start_ts = timer()
  465. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  466. with_res = WithResolution(
  467. task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
  468. progress_callback
  469. )
  470. await with_res()
  471. progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
  472. return
  473. elif task.get("task_type", "") == "graph_community":
  474. start_ts = timer()
  475. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  476. with_comm = WithCommunity(
  477. task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
  478. progress_callback
  479. )
  480. await with_comm()
  481. progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
  482. return
  483. else:
  484. # Standard chunking methods
  485. start_ts = timer()
  486. chunks = await build_chunks(task, progress_callback)
  487. logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
  488. if chunks is None:
  489. return
  490. if not chunks:
  491. progress_callback(1., msg=f"No chunk built from {task_document_name}")
  492. return
  493. # TODO: exception handler
  494. ## set_progress(task["did"], -1, "ERROR: ")
  495. progress_callback(msg="Generate {} chunks".format(len(chunks)))
  496. start_ts = timer()
  497. try:
  498. token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
  499. except Exception as e:
  500. error_message = "Generate embedding error:{}".format(str(e))
  501. progress_callback(-1, error_message)
  502. logging.exception(error_message)
  503. token_count = 0
  504. raise
  505. progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
  506. logging.info(progress_message)
  507. progress_callback(msg=progress_message)
  508. chunk_count = len(set([chunk["id"] for chunk in chunks]))
  509. start_ts = timer()
  510. doc_store_result = ""
  511. es_bulk_size = 4
  512. for b in range(0, len(chunks), es_bulk_size):
  513. doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
  514. if b % 128 == 0:
  515. progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
  516. if doc_store_result:
  517. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  518. progress_callback(-1, msg=error_message)
  519. raise Exception(error_message)
  520. chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
  521. chunk_ids_str = " ".join(chunk_ids)
  522. try:
  523. TaskService.update_chunk_ids(task["id"], chunk_ids_str)
  524. except DoesNotExist:
  525. logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
  526. doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
  527. return
  528. logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
  529. task_to_page, len(chunks),
  530. timer() - start_ts))
  531. DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
  532. time_cost = timer() - start_ts
  533. task_time_cost = timer() - task_start_ts
  534. progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
  535. logging.info(
  536. "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
  537. task_to_page, len(chunks),
  538. token_count, task_time_cost))
  539. async def handle_task():
  540. global DONE_TASKS, FAILED_TASKS
  541. redis_msg, task = await collect()
  542. if not task:
  543. return
  544. try:
  545. logging.info(f"handle_task begin for task {json.dumps(task)}")
  546. CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
  547. await do_handle_task(task)
  548. DONE_TASKS += 1
  549. CURRENT_TASKS.pop(task["id"], None)
  550. logging.info(f"handle_task done for task {json.dumps(task)}")
  551. except Exception as e:
  552. FAILED_TASKS += 1
  553. CURRENT_TASKS.pop(task["id"], None)
  554. try:
  555. set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
  556. except Exception:
  557. pass
  558. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  559. redis_msg.ack()
  560. async def report_status():
  561. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
  562. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  563. while True:
  564. try:
  565. now = datetime.now()
  566. group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
  567. if group_info is not None:
  568. PENDING_TASKS = int(group_info.get("pending", 0))
  569. LAG_TASKS = int(group_info.get("lag", 0))
  570. current = copy.deepcopy(CURRENT_TASKS)
  571. heartbeat = json.dumps({
  572. "name": CONSUMER_NAME,
  573. "now": now.astimezone().isoformat(timespec="milliseconds"),
  574. "boot_at": BOOT_AT,
  575. "pending": PENDING_TASKS,
  576. "lag": LAG_TASKS,
  577. "done": DONE_TASKS,
  578. "failed": FAILED_TASKS,
  579. "current": current,
  580. })
  581. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  582. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  583. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  584. if expired > 0:
  585. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  586. except Exception:
  587. logging.exception("report_status got exception")
  588. await trio.sleep(30)
  589. async def main():
  590. logging.info(r"""
  591. ______ __ ______ __
  592. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  593. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  594. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  595. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  596. """)
  597. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  598. settings.init_settings()
  599. print_rag_settings()
  600. signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
  601. signal.signal(signal.SIGUSR2, stop_tracemalloc)
  602. TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
  603. if TRACE_MALLOC_ENABLED:
  604. start_tracemalloc_and_snapshot(None, None)
  605. async with trio.open_nursery() as nursery:
  606. nursery.start_soon(report_status)
  607. while True:
  608. async with task_limiter:
  609. nursery.start_soon(handle_task)
  610. logging.error("BUG!!! You should not reach here!!!")
  611. if __name__ == "__main__":
  612. trio.run(main)