Przeglądaj źródła

Perf: make `do_cancel` quicker. (#8846)

### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
tags/v0.20.0
Kevin Hu 3 miesięcy temu
rodzic
commit
24c41d2a61
No account linked to committer's email address
3 zmienionych plików z 40 dodań i 33 usunięć
  1. 35
    29
      api/utils/api_utils.py
  2. 1
    0
      rag/raptor.py
  3. 4
    4
      rag/svr/task_executor.py

+ 35
- 29
api/utils/api_utils.py Wyświetl plik

@@ -606,6 +606,7 @@ TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(
seconds: float |int = None,
attempts: int = 2,
*,
exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None
@@ -625,41 +626,46 @@ def timeout(
thread.daemon = True
thread.start()

try:
result = result_queue.get(timeout=seconds)
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")
for a in range(attempts):
try:
result = result_queue.get(timeout=seconds)
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")

@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)

try:
with trio.fail_after(seconds):
return await func(*args, **kwargs)
except trio.TooSlowError:
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout

if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds")

if isinstance(exception, BaseException):
raise exception

if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds")

raise RuntimeError("Invalid exception type provided")
for a in range(attempts):
try:
with trio.fail_after(seconds):
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts -1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout

if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")

if isinstance(exception, BaseException):
raise exception

if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")

raise RuntimeError("Invalid exception type provided")

if asyncio.iscoroutinefunction(func):
return async_wrapper

+ 1
- 0
rag/raptor.py Wyświetl plik

@@ -42,6 +42,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._prompt = prompt
self._max_token = max_token

@timeout(60)
async def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
if response:

+ 4
- 4
rag/svr/task_executor.py Wyświetl plik

@@ -214,7 +214,7 @@ async def collect():
canceled = False
task = TaskService.get_task(msg["id"])
if task:
canceled = TaskService.do_cancel(task["id"])
canceled = DocumentService.do_cancel(task["doc_id"])
if not task or canceled:
state = "is unknown" if not task else "has been cancelled"
FAILED_TASKS += 1
@@ -382,7 +382,7 @@ async def build_chunks(task, progress_callback):

docs_to_tag = []
for d in docs:
task_canceled = TaskService.do_cancel(task["id"])
task_canceled = DocumentService.do_cancel(task["doc_id"])
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
@@ -531,7 +531,7 @@ async def do_handle_task(task):
progress_callback(-1, msg=error_message)
raise Exception(error_message)

task_canceled = TaskService.do_cancel(task_id)
task_canceled = DocumentService.do_cancel(task_doc_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
@@ -609,7 +609,7 @@ async def do_handle_task(task):

for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = TaskService.do_cancel(task_id)
task_canceled = DocumentService.do_cancel(task_doc_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return

Ładowanie…
Anuluj
Zapisz