Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

task_executor.py 31KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # from beartype import BeartypeConf
  16. # from beartype.claw import beartype_all # <-- you didn't sign up for this
  17. # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
  18. import random
  19. import sys
  20. import threading
  21. import time
  22. from api.utils.log_utils import init_root_logger, get_project_base_directory
  23. from graphrag.general.index import run_graphrag
  24. from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
  25. from rag.prompts import keyword_extraction, question_proposal, content_tagging
  26. import logging
  27. import os
  28. from datetime import datetime
  29. import json
  30. import xxhash
  31. import copy
  32. import re
  33. from functools import partial
  34. from io import BytesIO
  35. from multiprocessing.context import TimeoutError
  36. from timeit import default_timer as timer
  37. import tracemalloc
  38. import signal
  39. import trio
  40. import exceptiongroup
  41. import faulthandler
  42. import numpy as np
  43. from peewee import DoesNotExist
  44. from api.db import LLMType, ParserType
  45. from api.db.services.document_service import DocumentService
  46. from api.db.services.llm_service import LLMBundle
  47. from api.db.services.task_service import TaskService
  48. from api.db.services.file2document_service import File2DocumentService
  49. from api import settings
  50. from api.versions import get_ragflow_version
  51. from api.db.db_models import close_connection
  52. from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
  53. email, tag
  54. from rag.nlp import search, rag_tokenizer
  55. from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
  56. from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
  57. from rag.utils import num_tokens_from_string, truncate
  58. from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
  59. from rag.utils.storage_factory import STORAGE_IMPL
  60. from graphrag.utils import chat_limiter
  61. BATCH_SIZE = 64
  62. FACTORY = {
  63. "general": naive,
  64. ParserType.NAIVE.value: naive,
  65. ParserType.PAPER.value: paper,
  66. ParserType.BOOK.value: book,
  67. ParserType.PRESENTATION.value: presentation,
  68. ParserType.MANUAL.value: manual,
  69. ParserType.LAWS.value: laws,
  70. ParserType.QA.value: qa,
  71. ParserType.TABLE.value: table,
  72. ParserType.RESUME.value: resume,
  73. ParserType.PICTURE.value: picture,
  74. ParserType.ONE.value: one,
  75. ParserType.AUDIO.value: audio,
  76. ParserType.EMAIL.value: email,
  77. ParserType.KG.value: naive,
  78. ParserType.TAG.value: tag
  79. }
  80. UNACKED_ITERATOR = None
  81. CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
  82. CONSUMER_NAME = "task_executor_" + CONSUMER_NO
  83. BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
  84. PENDING_TASKS = 0
  85. LAG_TASKS = 0
  86. DONE_TASKS = 0
  87. FAILED_TASKS = 0
  88. CURRENT_TASKS = {}
  89. MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
  90. MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
  91. MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
  92. task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
  93. chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
  94. minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO)
  95. kg_limiter = trio.CapacityLimiter(2)
  96. WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
  97. stop_event = threading.Event()
  98. def signal_handler(sig, frame):
  99. logging.info("Received interrupt signal, shutting down...")
  100. stop_event.set()
  101. time.sleep(1)
  102. sys.exit(0)
  103. # SIGUSR1 handler: start tracemalloc and take snapshot
  104. def start_tracemalloc_and_snapshot(signum, frame):
  105. if not tracemalloc.is_tracing():
  106. logging.info("start tracemalloc")
  107. tracemalloc.start()
  108. else:
  109. logging.info("tracemalloc is already running")
  110. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  111. snapshot_file = f"snapshot_{timestamp}.trace"
  112. snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
  113. snapshot = tracemalloc.take_snapshot()
  114. snapshot.dump(snapshot_file)
  115. current, peak = tracemalloc.get_traced_memory()
  116. if sys.platform == "win32":
  117. import psutil
  118. process = psutil.Process()
  119. max_rss = process.memory_info().rss / 1024
  120. else:
  121. import resource
  122. max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  123. logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
  124. # SIGUSR2 handler: stop tracemalloc
  125. def stop_tracemalloc(signum, frame):
  126. if tracemalloc.is_tracing():
  127. logging.info("stop tracemalloc")
  128. tracemalloc.stop()
  129. else:
  130. logging.info("tracemalloc not running")
  131. class TaskCanceledException(Exception):
  132. def __init__(self, msg):
  133. self.msg = msg
  134. def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
  135. try:
  136. if prog is not None and prog < 0:
  137. msg = "[ERROR]" + msg
  138. cancel = TaskService.do_cancel(task_id)
  139. if cancel:
  140. msg += " [Canceled]"
  141. prog = -1
  142. if to_page > 0:
  143. if msg:
  144. if from_page < to_page:
  145. msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
  146. if msg:
  147. msg = datetime.now().strftime("%H:%M:%S") + " " + msg
  148. d = {"progress_msg": msg}
  149. if prog is not None:
  150. d["progress"] = prog
  151. TaskService.update_progress(task_id, d)
  152. close_connection()
  153. if cancel:
  154. raise TaskCanceledException(msg)
  155. logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
  156. except DoesNotExist:
  157. logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
  158. except Exception:
  159. logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")
  160. async def collect():
  161. global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
  162. global UNACKED_ITERATOR
  163. svr_queue_names = get_svr_queue_names()
  164. try:
  165. if not UNACKED_ITERATOR:
  166. UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
  167. try:
  168. redis_msg = next(UNACKED_ITERATOR)
  169. except StopIteration:
  170. for svr_queue_name in svr_queue_names:
  171. redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
  172. if redis_msg:
  173. break
  174. except Exception:
  175. logging.exception("collect got exception")
  176. return None, None
  177. if not redis_msg:
  178. return None, None
  179. msg = redis_msg.get_message()
  180. if not msg:
  181. logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
  182. redis_msg.ack()
  183. return None, None
  184. canceled = False
  185. task = TaskService.get_task(msg["id"])
  186. if task:
  187. canceled = TaskService.do_cancel(task["id"])
  188. if not task or canceled:
  189. state = "is unknown" if not task else "has been cancelled"
  190. FAILED_TASKS += 1
  191. logging.warning(f"collect task {msg['id']} {state}")
  192. redis_msg.ack()
  193. return None, None
  194. task["task_type"] = msg.get("task_type", "")
  195. return redis_msg, task
  196. async def get_storage_binary(bucket, name):
  197. return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
  198. async def build_chunks(task, progress_callback):
  199. if task["size"] > DOC_MAXIMUM_SIZE:
  200. set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
  201. (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
  202. return []
  203. chunker = FACTORY[task["parser_id"].lower()]
  204. try:
  205. st = timer()
  206. bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
  207. binary = await get_storage_binary(bucket, name)
  208. logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
  209. except TimeoutError:
  210. progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
  211. logging.exception(
  212. "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
  213. raise
  214. except Exception as e:
  215. if re.search("(No such file|not found)", str(e)):
  216. progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
  217. else:
  218. progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
  219. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  220. raise
  221. try:
  222. async with chunk_limiter:
  223. cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
  224. to_page=task["to_page"], lang=task["language"], callback=progress_callback,
  225. kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
  226. logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
  227. except TaskCanceledException:
  228. raise
  229. except Exception as e:
  230. progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
  231. logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
  232. raise
  233. docs = []
  234. doc = {
  235. "doc_id": task["doc_id"],
  236. "kb_id": str(task["kb_id"])
  237. }
  238. if task["pagerank"]:
  239. doc[PAGERANK_FLD] = int(task["pagerank"])
  240. st = timer()
  241. async def upload_to_minio(document, chunk):
  242. try:
  243. d = copy.deepcopy(document)
  244. d.update(chunk)
  245. d["id"] = xxhash.xxh64((chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
  246. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  247. d["create_timestamp_flt"] = datetime.now().timestamp()
  248. if not d.get("image"):
  249. _ = d.pop("image", None)
  250. d["img_id"] = ""
  251. docs.append(d)
  252. return
  253. output_buffer = BytesIO()
  254. if isinstance(d["image"], bytes):
  255. output_buffer = BytesIO(d["image"])
  256. else:
  257. # If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
  258. if d["image"].mode in ("RGBA", "P"):
  259. d["image"] = d["image"].convert("RGB")
  260. d["image"].save(output_buffer, format='JPEG')
  261. async with minio_limiter:
  262. await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
  263. d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
  264. del d["image"]
  265. docs.append(d)
  266. except Exception:
  267. logging.exception(
  268. "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
  269. raise
  270. async with trio.open_nursery() as nursery:
  271. for ck in cks:
  272. nursery.start_soon(upload_to_minio, doc, ck)
  273. el = timer() - st
  274. logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
  275. if task["parser_config"].get("auto_keywords", 0):
  276. st = timer()
  277. progress_callback(msg="Start to generate keywords for every chunk ...")
  278. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  279. async def doc_keyword_extraction(chat_mdl, d, topn):
  280. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
  281. if not cached:
  282. async with chat_limiter:
  283. cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
  284. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
  285. if cached:
  286. d["important_kwd"] = cached.split(",")
  287. d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
  288. return
  289. async with trio.open_nursery() as nursery:
  290. for d in docs:
  291. nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
  292. progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  293. if task["parser_config"].get("auto_questions", 0):
  294. st = timer()
  295. progress_callback(msg="Start to generate questions for every chunk ...")
  296. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  297. async def doc_question_proposal(chat_mdl, d, topn):
  298. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
  299. if not cached:
  300. async with chat_limiter:
  301. cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
  302. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
  303. if cached:
  304. d["question_kwd"] = cached.split("\n")
  305. d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
  306. async with trio.open_nursery() as nursery:
  307. for d in docs:
  308. nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
  309. progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  310. if task["kb_parser_config"].get("tag_kb_ids", []):
  311. progress_callback(msg="Start to tag for every chunk ...")
  312. kb_ids = task["kb_parser_config"]["tag_kb_ids"]
  313. tenant_id = task["tenant_id"]
  314. topn_tags = task["kb_parser_config"].get("topn_tags", 3)
  315. S = 1000
  316. st = timer()
  317. examples = []
  318. all_tags = get_tags_from_cache(kb_ids)
  319. if not all_tags:
  320. all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
  321. set_tags_to_cache(kb_ids, all_tags)
  322. else:
  323. all_tags = json.loads(all_tags)
  324. chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
  325. docs_to_tag = []
  326. for d in docs:
  327. task_canceled = TaskService.do_cancel(task["id"])
  328. if task_canceled:
  329. progress_callback(-1, msg="Task has been canceled.")
  330. return
  331. if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
  332. examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
  333. else:
  334. docs_to_tag.append(d)
  335. async def doc_content_tagging(chat_mdl, d, topn_tags):
  336. cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
  337. if not cached:
  338. picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
  339. if not picked_examples:
  340. picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
  341. async with chat_limiter:
  342. cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
  343. if cached:
  344. cached = json.dumps(cached)
  345. if cached:
  346. set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
  347. d[TAG_FLD] = json.loads(cached)
  348. async with trio.open_nursery() as nursery:
  349. for d in docs_to_tag:
  350. nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
  351. progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
  352. return docs
  353. def init_kb(row, vector_size: int):
  354. idxnm = search.index_name(row["tenant_id"])
  355. return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
  356. async def embedding(docs, mdl, parser_config=None, callback=None):
  357. if parser_config is None:
  358. parser_config = {}
  359. tts, cnts = [], []
  360. for d in docs:
  361. tts.append(d.get("docnm_kwd", "Title"))
  362. c = "\n".join(d.get("question_kwd", []))
  363. if not c:
  364. c = d["content_with_weight"]
  365. c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
  366. if not c:
  367. c = "None"
  368. cnts.append(c)
  369. tk_count = 0
  370. if len(tts) == len(cnts):
  371. vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
  372. tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
  373. tk_count += c
  374. cnts_ = np.array([])
  375. for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE):
  376. vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + EMBEDDING_BATCH_SIZE]]))
  377. if len(cnts_) == 0:
  378. cnts_ = vts
  379. else:
  380. cnts_ = np.concatenate((cnts_, vts), axis=0)
  381. tk_count += c
  382. callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
  383. cnts = cnts_
  384. title_w = float(parser_config.get("filename_embd_weight", 0.1))
  385. vects = (title_w * tts + (1 - title_w) *
  386. cnts) if len(tts) == len(cnts) else cnts
  387. assert len(vects) == len(docs)
  388. vector_size = 0
  389. for i, d in enumerate(docs):
  390. v = vects[i].tolist()
  391. vector_size = len(v)
  392. d["q_%d_vec" % len(v)] = v
  393. return tk_count, vector_size
  394. async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
  395. chunks = []
  396. vctr_nm = "q_%d_vec"%vector_size
  397. for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
  398. fields=["content_with_weight", vctr_nm]):
  399. chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
  400. raptor = Raptor(
  401. row["parser_config"]["raptor"].get("max_cluster", 64),
  402. chat_mdl,
  403. embd_mdl,
  404. row["parser_config"]["raptor"]["prompt"],
  405. row["parser_config"]["raptor"]["max_token"],
  406. row["parser_config"]["raptor"]["threshold"]
  407. )
  408. original_length = len(chunks)
  409. chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
  410. doc = {
  411. "doc_id": row["doc_id"],
  412. "kb_id": [str(row["kb_id"])],
  413. "docnm_kwd": row["name"],
  414. "title_tks": rag_tokenizer.tokenize(row["name"])
  415. }
  416. if row["pagerank"]:
  417. doc[PAGERANK_FLD] = int(row["pagerank"])
  418. res = []
  419. tk_count = 0
  420. for content, vctr in chunks[original_length:]:
  421. d = copy.deepcopy(doc)
  422. d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
  423. d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
  424. d["create_timestamp_flt"] = datetime.now().timestamp()
  425. d[vctr_nm] = vctr.tolist()
  426. d["content_with_weight"] = content
  427. d["content_ltks"] = rag_tokenizer.tokenize(content)
  428. d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
  429. res.append(d)
  430. tk_count += num_tokens_from_string(content)
  431. return res, tk_count
  432. async def do_handle_task(task):
  433. task_id = task["id"]
  434. task_from_page = task["from_page"]
  435. task_to_page = task["to_page"]
  436. task_tenant_id = task["tenant_id"]
  437. task_embedding_id = task["embd_id"]
  438. task_language = task["language"]
  439. task_llm_id = task["llm_id"]
  440. task_dataset_id = task["kb_id"]
  441. task_doc_id = task["doc_id"]
  442. task_document_name = task["name"]
  443. task_parser_config = task["parser_config"]
  444. task_start_ts = timer()
  445. # prepare the progress callback function
  446. progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
  447. # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
  448. lower_case_doc_engine = settings.DOC_ENGINE.lower()
  449. if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
  450. error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
  451. progress_callback(-1, msg=error_message)
  452. raise Exception(error_message)
  453. task_canceled = TaskService.do_cancel(task_id)
  454. if task_canceled:
  455. progress_callback(-1, msg="Task has been canceled.")
  456. return
  457. try:
  458. # bind embedding model
  459. embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
  460. vts, _ = embedding_model.encode(["ok"])
  461. vector_size = len(vts[0])
  462. except Exception as e:
  463. error_message = f'Fail to bind embedding model: {str(e)}'
  464. progress_callback(-1, msg=error_message)
  465. logging.exception(error_message)
  466. raise
  467. init_kb(task, vector_size)
  468. # Either using RAPTOR or Standard chunking methods
  469. if task.get("task_type", "") == "raptor":
  470. # bind LLM for raptor
  471. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  472. # run RAPTOR
  473. async with kg_limiter:
  474. chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
  475. # Either using graphrag or Standard chunking methods
  476. elif task.get("task_type", "") == "graphrag":
  477. if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
  478. progress_callback(prog=-1.0, msg="Internal configuration error.")
  479. return
  480. graphrag_conf = task["kb_parser_config"].get("graphrag", {})
  481. start_ts = timer()
  482. chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
  483. with_resolution = graphrag_conf.get("resolution", False)
  484. with_community = graphrag_conf.get("community", False)
  485. async with kg_limiter:
  486. await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
  487. progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
  488. return
  489. else:
  490. # Standard chunking methods
  491. start_ts = timer()
  492. chunks = await build_chunks(task, progress_callback)
  493. logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
  494. if not chunks:
  495. progress_callback(1., msg=f"No chunk built from {task_document_name}")
  496. return
  497. # TODO: exception handler
  498. ## set_progress(task["did"], -1, "ERROR: ")
  499. progress_callback(msg="Generate {} chunks".format(len(chunks)))
  500. start_ts = timer()
  501. try:
  502. token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
  503. except Exception as e:
  504. error_message = "Generate embedding error:{}".format(str(e))
  505. progress_callback(-1, error_message)
  506. logging.exception(error_message)
  507. token_count = 0
  508. raise
  509. progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
  510. logging.info(progress_message)
  511. progress_callback(msg=progress_message)
  512. chunk_count = len(set([chunk["id"] for chunk in chunks]))
  513. start_ts = timer()
  514. doc_store_result = ""
  515. async def delete_image(kb_id, chunk_id):
  516. try:
  517. async with minio_limiter:
  518. STORAGE_IMPL.delete(kb_id, chunk_id)
  519. except Exception:
  520. logging.exception(
  521. "Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
  522. raise
  523. for b in range(0, len(chunks), DOC_BULK_SIZE):
  524. doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
  525. task_canceled = TaskService.do_cancel(task_id)
  526. if task_canceled:
  527. progress_callback(-1, msg="Task has been canceled.")
  528. return
  529. if b % 128 == 0:
  530. progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
  531. if doc_store_result:
  532. error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
  533. progress_callback(-1, msg=error_message)
  534. raise Exception(error_message)
  535. chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
  536. chunk_ids_str = " ".join(chunk_ids)
  537. try:
  538. TaskService.update_chunk_ids(task["id"], chunk_ids_str)
  539. except DoesNotExist:
  540. logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
  541. doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
  542. async with trio.open_nursery() as nursery:
  543. for chunk_id in chunk_ids:
  544. nursery.start_soon(delete_image, task_dataset_id, chunk_id)
  545. progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
  546. return
  547. logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
  548. task_to_page, len(chunks),
  549. timer() - start_ts))
  550. DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
  551. time_cost = timer() - start_ts
  552. task_time_cost = timer() - task_start_ts
  553. progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
  554. logging.info(
  555. "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
  556. task_to_page, len(chunks),
  557. token_count, task_time_cost))
  558. async def handle_task():
  559. global DONE_TASKS, FAILED_TASKS
  560. redis_msg, task = await collect()
  561. if not task:
  562. await trio.sleep(5)
  563. return
  564. try:
  565. logging.info(f"handle_task begin for task {json.dumps(task)}")
  566. CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
  567. await do_handle_task(task)
  568. DONE_TASKS += 1
  569. CURRENT_TASKS.pop(task["id"], None)
  570. logging.info(f"handle_task done for task {json.dumps(task)}")
  571. except Exception as e:
  572. FAILED_TASKS += 1
  573. CURRENT_TASKS.pop(task["id"], None)
  574. try:
  575. err_msg = str(e)
  576. while isinstance(e, exceptiongroup.ExceptionGroup):
  577. e = e.exceptions[0]
  578. err_msg += ' -- ' + str(e)
  579. set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
  580. except Exception:
  581. pass
  582. logging.exception(f"handle_task got exception for task {json.dumps(task)}")
  583. redis_msg.ack()
  584. async def report_status():
  585. global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
  586. REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
  587. redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
  588. while True:
  589. try:
  590. now = datetime.now()
  591. group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
  592. if group_info is not None:
  593. PENDING_TASKS = int(group_info.get("pending", 0))
  594. LAG_TASKS = int(group_info.get("lag", 0))
  595. current = copy.deepcopy(CURRENT_TASKS)
  596. heartbeat = json.dumps({
  597. "name": CONSUMER_NAME,
  598. "now": now.astimezone().isoformat(timespec="milliseconds"),
  599. "boot_at": BOOT_AT,
  600. "pending": PENDING_TASKS,
  601. "lag": LAG_TASKS,
  602. "done": DONE_TASKS,
  603. "failed": FAILED_TASKS,
  604. "current": current,
  605. })
  606. REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
  607. logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
  608. expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
  609. if expired > 0:
  610. REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
  611. # clean task executor
  612. if redis_lock.acquire():
  613. task_executors = REDIS_CONN.smembers("TASKEXE")
  614. for consumer_name in task_executors:
  615. if consumer_name == CONSUMER_NAME:
  616. continue
  617. expired = REDIS_CONN.zcount(
  618. consumer_name, now.timestamp() - WORKER_HEARTBEAT_TIMEOUT, now.timestamp() + 10
  619. )
  620. if expired == 0:
  621. logging.info(f"{consumer_name} expired, removed")
  622. REDIS_CONN.srem("TASKEXE", consumer_name)
  623. REDIS_CONN.delete(consumer_name)
  624. except Exception:
  625. logging.exception("report_status got exception")
  626. finally:
  627. redis_lock.release()
  628. await trio.sleep(30)
  629. async def task_manager():
  630. try:
  631. await handle_task()
  632. finally:
  633. task_limiter.release()
  634. async def main():
  635. logging.info(r"""
  636. ______ __ ______ __
  637. /_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
  638. / / / __ `/ ___/ //_/ / __/ | |/_/ _ \/ ___/ / / / __/ __ \/ ___/
  639. / / / /_/ (__ ) ,< / /____> </ __/ /__/ /_/ / /_/ /_/ / /
  640. /_/ \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
  641. """)
  642. logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
  643. settings.init_settings()
  644. print_rag_settings()
  645. if sys.platform != "win32":
  646. signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
  647. signal.signal(signal.SIGUSR2, stop_tracemalloc)
  648. TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
  649. if TRACE_MALLOC_ENABLED:
  650. start_tracemalloc_and_snapshot(None, None)
  651. signal.signal(signal.SIGINT, signal_handler)
  652. signal.signal(signal.SIGTERM, signal_handler)
  653. async with trio.open_nursery() as nursery:
  654. nursery.start_soon(report_status)
  655. while not stop_event.is_set():
  656. await task_limiter.acquire()
  657. nursery.start_soon(task_manager)
  658. logging.error("BUG!!! You should not reach here!!!")
# Script entry point: runs one task-executor process. The consumer number
# comes from the first CLI argument (see CONSUMER_NO).
if __name__ == "__main__":
    # Dump Python tracebacks to stderr on fatal errors such as segfaults.
    faulthandler.enable()
    # Configure root logging; CONSUMER_NAME presumably names this process's
    # log output — confirm in api.utils.log_utils.init_root_logger.
    init_root_logger(CONSUMER_NAME)
    # Run the async entry point on the trio event loop.
    trio.run(main)