
task_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import random
import xxhash
from datetime import datetime
from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
from peewee import JOIN
from api.db.db_models import DB, File2Document, File
from api.db import StatusEnum, FileType, TaskStatus
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, get_uuid
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search


def trim_header_by_lines(text: str, max_length) -> str:
    """Trim header text to a maximum length while preserving line breaks.

    Args:
        text: Input text to trim.
        max_length: Maximum allowed length.

    Returns:
        The trimmed text.
    """
    len_text = len(text)
    if len_text <= max_length:
        return text
    for i in range(len_text):
        if text[i] == '\n' and len_text - i <= max_length:
            return text[i + 1:]
    return text
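
# A quick doctest-style sketch of the function above (editor's addition, not
# part of the upstream module). The cut always lands just after a newline, so
# the kept tail fits within max_length without a partial leading line:
#
#     >>> trim_header_by_lines("aa\nbb\ncc", 6)
#     'bb\ncc'
#     >>> trim_header_by_lines("aa\nbb\ncc", 8)   # already short enough
#     'aa\nbb\ncc'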


class TaskService(CommonService):
    """Service class for managing document processing tasks.

    This class extends CommonService to provide specialized functionality for
    document processing task management, including task creation, progress
    tracking, and chunk management. It handles various document types (PDF,
    Excel, etc.) and manages their processing lifecycle.

    The class implements a robust task queue system with retry mechanisms and
    progress tracking, supporting both synchronous and asynchronous task
    execution.

    Attributes:
        model: The Task model class for database operations.
    """

    model = Task

    @classmethod
    @DB.connection_context()
    def get_task(cls, task_id):
        """Retrieve detailed task information by task ID.

        This method fetches comprehensive task details including associated
        document, knowledge base, and tenant information. It also handles task
        retry logic and progress updates.

        Args:
            task_id (str): The unique identifier of the task to retrieve.

        Returns:
            dict: Task details dictionary containing all task information and
                related metadata. Returns None if the task is not found or has
                exceeded the retry limit.
        """
        fields = [
            cls.model.id,
            cls.model.doc_id,
            cls.model.from_page,
            cls.model.to_page,
            cls.model.retry_count,
            Document.kb_id,
            Document.parser_id,
            Document.parser_config,
            Document.name,
            Document.type,
            Document.location,
            Document.size,
            Knowledgebase.tenant_id,
            Knowledgebase.language,
            Knowledgebase.embd_id,
            Knowledgebase.pagerank,
            Knowledgebase.parser_config.alias("kb_parser_config"),
            Tenant.img2txt_id,
            Tenant.asr_id,
            Tenant.llm_id,
            cls.model.update_time,
        ]
        docs = (
            cls.model.select(*fields)
            .join(Document, on=(cls.model.doc_id == Document.id))
            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(cls.model.id == task_id)
        )
        docs = list(docs.dicts())
        if not docs:
            return None

        msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task has been received."
        prog = random.random() / 10.0
        if docs[0]["retry_count"] >= 3:
            msg = "\nERROR: Task is abandoned after 3 attempts."
            prog = -1
        cls.model.update(
            progress_msg=cls.model.progress_msg + msg,
            progress=prog,
            retry_count=docs[0]["retry_count"] + 1,
        ).where(cls.model.id == docs[0]["id"]).execute()

        if docs[0]["retry_count"] >= 3:
            return None
        return docs[0]
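
    # Intended call pattern (editor's sketch): a worker resolves its queued task
    # id and treats None as "skip this task" -- either the row is gone or it has
    # already been retried three times:
    #
    #     task = TaskService.get_task(task_id)
    #     if task is None:
    #         return  # abandoned or missing; the progress fields already say why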

    @classmethod
    @DB.connection_context()
    def get_tasks(cls, doc_id: str):
        """Retrieve all tasks associated with a document.

        This method fetches all processing tasks for a given document, ordered
        by page number and creation time. It includes task progress and chunk
        information.

        Args:
            doc_id (str): The unique identifier of the document.

        Returns:
            list[dict]: List of task dictionaries containing task details.
                Returns None if no tasks are found.
        """
        fields = [
            cls.model.id,
            cls.model.from_page,
            cls.model.progress,
            cls.model.digest,
            cls.model.chunk_ids,
        ]
        tasks = (
            cls.model.select(*fields)
            .order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
            .where(cls.model.doc_id == doc_id)
        )
        tasks = list(tasks.dicts())
        if not tasks:
            return None
        return tasks

    @classmethod
    @DB.connection_context()
    def update_chunk_ids(cls, id: str, chunk_ids: str):
        """Update the chunk IDs associated with a task.

        This method updates the chunk_ids field of a task, which stores the IDs
        of processed document chunks as a space-separated string.

        Args:
            id (str): The unique identifier of the task.
            chunk_ids (str): Space-separated string of chunk identifiers.
        """
        cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()

    @classmethod
    @DB.connection_context()
    def get_ongoing_doc_name(cls):
        """Get names of documents that are currently being processed.

        This method retrieves information about documents that are in the
        processing state, including their locations and associated IDs. It uses
        a database lock to ensure thread safety when accessing the task
        information.

        Returns:
            list[tuple]: A list of tuples, each containing (parent_id/kb_id,
                location) for documents currently being processed. Returns an
                empty list if no documents are being processed.
        """
        with DB.lock("get_task", -1):
            docs = (
                cls.model.select(
                    *[Document.id, Document.kb_id, Document.location, File.parent_id]
                )
                .join(Document, on=(cls.model.doc_id == Document.id))
                .join(
                    File2Document,
                    on=(File2Document.document_id == Document.id),
                    join_type=JOIN.LEFT_OUTER,
                )
                .join(
                    File,
                    on=(File2Document.file_id == File.id),
                    join_type=JOIN.LEFT_OUTER,
                )
                .where(
                    Document.status == StatusEnum.VALID.value,
                    Document.run == TaskStatus.RUNNING.value,
                    ~(Document.type == FileType.VIRTUAL.value),
                    cls.model.progress < 1,
                    cls.model.create_time >= current_timestamp() - 1000 * 600,
                )
            )
            docs = list(docs.dicts())
            if not docs:
                return []

            return list(
                set(
                    [
                        (
                            d["parent_id"] if d["parent_id"] else d["kb_id"],
                            d["location"],
                        )
                        for d in docs
                    ]
                )
            )
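
    # Example return shape (editor's illustration; ids and names are made up):
    #     [("folder-or-kb-id", "annual_report.pdf"), ("kb0001", "sales.xlsx")]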

    @classmethod
    @DB.connection_context()
    def do_cancel(cls, id):
        """Check whether a task should be cancelled based on its document status.

        This method determines whether a task should be cancelled by checking
        the associated document's run status and progress. A task should be
        cancelled if its document is marked for cancellation or has negative
        progress.

        Args:
            id (str): The unique identifier of the task to check.

        Returns:
            bool: True if the task should be cancelled, False otherwise.
        """
        task = cls.model.get_by_id(id)
        _, doc = DocumentService.get_by_id(task.doc_id)
        return doc.run == TaskStatus.CANCEL.value or doc.progress < 0

    @classmethod
    @DB.connection_context()
    def update_progress(cls, id, info):
        """Update the progress information for a task.

        This method updates both the progress message and completion percentage
        of a task. It handles platform-specific behavior (macOS vs. others) and
        uses a database lock when necessary to ensure thread safety.

        Update rules:
            - progress_msg: Always appends the new message to the existing one,
              then trims the result to at most 3000 characters, cut at a line
              boundary.
            - progress: Only updates if the current progress is not -1 AND
              (the new progress is -1 OR greater than the existing progress),
              to avoid overwriting valid progress with invalid or regressive
              values.

        Args:
            id (str): The unique identifier of the task to update.
            info (dict): Dictionary containing progress information with keys:
                - progress_msg (str, optional): Progress message to append.
                - progress (float, optional): Progress percentage (0.0 to 1.0).
        """
        task = cls.model.get_by_id(id)
        if not task:
            logging.warning("Update_progress error: task not found")
            return

        if os.environ.get("MACOS"):
            if info.get("progress_msg"):
                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
            if "progress" in info:
                prog = info["progress"]
                cls.model.update(progress=prog).where(
                    (cls.model.id == id) &
                    (
                        (cls.model.progress != -1) &
                        ((prog == -1) | (prog > cls.model.progress))
                    )
                ).execute()
            return

        with DB.lock("update_progress", -1):
            if info.get("progress_msg"):
                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
            if "progress" in info:
                prog = info["progress"]
                cls.model.update(progress=prog).where(
                    (cls.model.id == id) &
                    (
                        (cls.model.progress != -1) &
                        ((prog == -1) | (prog > cls.model.progress))
                    )
                ).execute()
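
    # The guard above in a nutshell (editor's sketch): with a stored progress of 0.5,
    #     update_progress(id, {"progress_msg": "layout done", "progress": 0.7})   # 0.7 applied (0.7 > 0.5)
    #     update_progress(id, {"progress_msg": "stale worker", "progress": 0.3})  # 0.3 ignored; message still appended
    #     update_progress(id, {"progress_msg": "FATAL: OOM", "progress": -1})     # -1 applied; task marked failed
    # Once progress is -1, no later update can change it.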


def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
    """Create and queue document processing tasks.

    This function creates processing tasks for a document based on its type and
    configuration. It handles different document types (PDF, Excel, etc.)
    differently and manages task chunking and configuration. It also implements
    a task reuse optimization by checking for previously completed tasks.

    Args:
        doc (dict): Document dictionary containing metadata and configuration.
        bucket (str): Storage bucket name where the document is stored.
        name (str): File name of the document.
        priority (int): Priority level for task queueing.

    Note:
        - For PDF documents, tasks are created per page range based on configuration.
        - For Excel documents, tasks are created per row range.
        - Task digests are calculated for optimization and reuse.
        - Previous task chunks may be reused if available.
    """
    def new_task():
        return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}

    parse_task_array = []

    if doc["type"] == FileType.PDF.value:
        file_bin = STORAGE_IMPL.get(bucket, name)
        do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
        pages = PdfParser.total_page_number(doc["name"], file_bin)
        if pages is None:
            pages = 0
        page_size = doc["parser_config"].get("task_page_size") or 12
        if doc["parser_id"] == "paper":
            page_size = doc["parser_config"].get("task_page_size") or 22
        if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
            page_size = 10 ** 9
        page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
        for s, e in page_ranges:
            s -= 1
            s = max(0, s)
            e = min(e - 1, pages)
            for p in range(s, e, page_size):
                task = new_task()
                task["from_page"] = p
                task["to_page"] = min(p + page_size, e)
                parse_task_array.append(task)
    elif doc["parser_id"] == "table":
        file_bin = STORAGE_IMPL.get(bucket, name)
        rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
        for i in range(0, rn, 3000):
            task = new_task()
            task["from_page"] = i
            task["to_page"] = min(i + 3000, rn)
            parse_task_array.append(task)
    else:
        parse_task_array.append(new_task())

    chunking_config = DocumentService.get_chunking_config(doc["id"])
    for task in parse_task_array:
        hasher = xxhash.xxh64()
        for field in sorted(chunking_config.keys()):
            if field == "parser_config":
                for k in ["raptor", "graphrag"]:
                    if k in chunking_config[field]:
                        del chunking_config[field][k]
            hasher.update(str(chunking_config[field]).encode("utf-8"))
        for field in ["doc_id", "from_page", "to_page"]:
            hasher.update(str(task.get(field, "")).encode("utf-8"))
        task_digest = hasher.hexdigest()
        task["digest"] = task_digest
        task["progress"] = 0.0
        task["priority"] = priority

    prev_tasks = TaskService.get_tasks(doc["id"])
    ck_num = 0
    if prev_tasks:
        for task in parse_task_array:
            ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
        TaskService.filter_delete([Task.doc_id == doc["id"]])
        pre_chunk_ids = []
        for pre_task in prev_tasks:
            if pre_task["chunk_ids"]:
                pre_chunk_ids.extend(pre_task["chunk_ids"].split())
        if pre_chunk_ids:
            settings.docStoreConn.delete({"id": pre_chunk_ids},
                                         search.index_name(chunking_config["tenant_id"]),
                                         chunking_config["kb_id"])
    DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

    bulk_insert_into_db(Task, parse_task_array, True)
    DocumentService.begin2parse(doc["id"])

    unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
    for unfinished_task in unfinished_task_array:
        assert REDIS_CONN.queue_product(
            get_svr_queue_name(priority), message=unfinished_task
        ), "Can't access Redis. Please check the Redis' status."


def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
    """Attempt to reuse chunks from previous tasks for optimization.

    This function checks whether chunks from previously completed tasks can be
    reused for the current task, which can significantly improve processing
    efficiency. It matches tasks based on page ranges and configuration digests.

    Args:
        task (dict): Current task dictionary to potentially reuse chunks for.
        prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
        chunking_config (dict): Configuration dictionary for chunk processing.

    Returns:
        int: Number of chunks successfully reused. Returns 0 if no chunks could
            be reused.

    Note:
        Chunks can only be reused if:
        - A previous task exists with a matching page range and configuration digest.
        - The previous task was completed successfully (progress = 1.0).
        - The previous task has valid chunk IDs.
    """
    idx = 0
    while idx < len(prev_tasks):
        prev_task = prev_tasks[idx]
        if prev_task.get("from_page", 0) == task.get("from_page", 0) \
                and prev_task.get("digest", "") == task.get("digest", ""):
            break
        idx += 1

    if idx >= len(prev_tasks):
        return 0
    prev_task = prev_tasks[idx]
    if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
        return 0
    task["chunk_ids"] = prev_task["chunk_ids"]
    task["progress"] = 1.0
    if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
        task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
    else:
        task["progress_msg"] = ""
    task["progress_msg"] = " ".join(
        [datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
    prev_task["chunk_ids"] = ""
    return len(task["chunk_ids"].split())
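
# Reuse in miniature (editor's sketch; dict values are illustrative). A finished
# previous task with the same from_page and digest donates its chunk ids:
#
#     new = {"from_page": 0, "to_page": 12, "digest": "abc"}
#     old = [{"from_page": 0, "digest": "abc", "progress": 1.0, "chunk_ids": "c1 c2"}]
#     reuse_prev_task_chunks(new, old, {})  # -> 2; new now holds "c1 c2" at progress 1.0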


def cancel_all_task_of(doc_id):
    """Set a cancel flag in Redis for every task of the given document."""
    for t in TaskService.query(doc_id=doc_id):
        try:
            REDIS_CONN.set(f"{t.id}-cancel", "x")
        except Exception as e:
            logging.exception(e)


def has_canceled(task_id):
    """Return True if a cancel flag exists in Redis for this task."""
    try:
        if REDIS_CONN.get(f"{task_id}-cancel"):
            return True
    except Exception as e:
        logging.exception(e)
    return False
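
# Cancellation round trip (editor's sketch): the API side flags every task of a
# document in Redis, and workers poll the flag between pipeline steps:
#
#     cancel_all_task_of(doc_id)   # producer sets "<task_id>-cancel" keys
#     ...
#     if has_canceled(task_id):    # consumer checks periodically
#         abort_current_task()     # hypothetical worker-side reaction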