浏览代码

Perf: use Redis to check if a task is canceled. (#8853)

### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
tags/v0.20.0
Kevin Hu 3 个月前
父节点
当前提交
aa4a725529
没有帐户链接到提交者的电子邮件
共有 3 个文件被更改,包括 27 次插入7 次删除
  1. 4
    1
      api/apps/document_app.py
  2. 17
    0
      api/db/services/task_service.py
  3. 6
    6
      rag/svr/task_executor.py

+ 4
- 1
api/apps/document_app.py 查看文件

from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)


if str(req["run"]) == TaskStatus.CANCEL.value:
cancel_all_task_of(id)

if str(req["run"]) == TaskStatus.RUNNING.value: if str(req["run"]) == TaskStatus.RUNNING.value:
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict() doc = doc.to_dict()

+ 17
- 0
api/db/services/task_service.py 查看文件

prev_task["chunk_ids"] = "" prev_task["chunk_ids"] = ""


return len(task["chunk_ids"].split()) return len(task["chunk_ids"].split())


def cancel_all_task_of(doc_id):
    """Flag every task of a document as canceled in Redis.

    For each task returned by ``TaskService.query(doc_id=doc_id)`` the key
    ``"<task id>-cancel"`` is written with the value ``"x"``.

    Args:
        doc_id: Identifier of the document whose tasks should be flagged.

    Failures while flagging a single task are logged and do not stop the
    remaining tasks from being processed.
    """
    tasks = TaskService.query(doc_id=doc_id)
    for task in tasks:
        try:
            # Build the key inside the guard so any failure for this task
            # is logged rather than aborting the whole loop.
            cancel_key = f"{task.id}-cancel"
            REDIS_CONN.set(cancel_key, "x")
        except Exception as err:
            logging.exception(err)

def has_canceled(task_id):
    """Tell whether a cancel flag exists in Redis for the given task.

    Args:
        task_id: Identifier of the task; the key looked up is
            ``"<task_id>-cancel"``.

    Returns:
        ``True`` when the lookup yields a truthy value, ``False`` when it
        yields nothing or when the lookup raises (the error is logged).
    """
    try:
        flagged = bool(REDIS_CONN.get(f"{task_id}-cancel"))
    except Exception as err:
        # A failed lookup is treated as "not canceled" after logging.
        logging.exception(err)
        return False
    return flagged

+ 6
- 6
rag/svr/task_executor.py 查看文件

from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.task_service import TaskService, has_canceled
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api import settings from api import settings
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
try: try:
if prog is not None and prog < 0: if prog is not None and prog < 0:
msg = "[ERROR]" + msg msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)
cancel = has_canceled(task_id)


if cancel: if cancel:
msg += " [Canceled]" msg += " [Canceled]"
canceled = False canceled = False
task = TaskService.get_task(msg["id"]) task = TaskService.get_task(msg["id"])
if task: if task:
canceled = DocumentService.do_cancel(task["doc_id"])
canceled = has_canceled(task["id"])
if not task or canceled: if not task or canceled:
state = "is unknown" if not task else "has been cancelled" state = "is unknown" if not task else "has been cancelled"
FAILED_TASKS += 1 FAILED_TASKS += 1


docs_to_tag = [] docs_to_tag = []
for d in docs: for d in docs:
task_canceled = DocumentService.do_cancel(task["doc_id"])
task_canceled = has_canceled(task["id"])
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return
progress_callback(-1, msg=error_message) progress_callback(-1, msg=error_message)
raise Exception(error_message) raise Exception(error_message)


task_canceled = DocumentService.do_cancel(task_doc_id)
task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return


for b in range(0, len(chunks), DOC_BULK_SIZE): for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = DocumentService.do_cancel(task_doc_id)
task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return return

正在加载...
取消
保存