
## Perf: set timeout for building chunks. (#8940)

### What problem does this PR solve?

Chunk building had no overall time limit, so a stuck model call could hang a task indefinitely. This PR puts a 40-minute timeout on `build_chunks` and moves the model pressure test (`is_strong_enough`) out of the GraphRAG indexing path into `task_executor.py`, probing each model right after it is bound.
### Type of change

- [x] Performance Improvement
Tag: v0.20.0
Author: Kevin Hu, 3 months ago
Commit: c783d90ba3
3 changed files with 14 additions and 10 deletions:

1. api/utils/api_utils.py (+8, -6)
2. graphrag/general/index.py (+1, -4)
3. rag/svr/task_executor.py (+5, -0)

### api/utils/api_utils.py (+8, -6)

```diff
     @timeout(30, 2)
     async def _is_strong_enough():
         nonlocal chat_model, embedding_model
-        with trio.fail_after(3):
-            _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
-        with trio.fail_after(30):
-            res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
-        if res.find("**ERROR**") >= 0:
-            raise Exception(res)
+        if embedding_model:
+            with trio.fail_after(3):
+                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
+        if chat_model:
+            with trio.fail_after(30):
+                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
+            if res.find("**ERROR**") >= 0:
+                raise Exception(res)

     # Pressure test for GraphRAG task
     async with trio.open_nursery() as nursery:
```
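The probe now skips whichever model is passed as `None`, which is what allows the new per-model calls in `task_executor.py` below. Here is a self-contained sketch of the same pattern using real `trio` APIs; `FakeEmbeddingModel` and `probe` are illustrative stand-ins for RAGFlow's `LLMBundle` and `is_strong_enough`, not code from this PR:

```python
import trio


class FakeEmbeddingModel:
    """Stand-in for an LLMBundle embedder; encode() blocks like a real client would."""

    def encode(self, texts):
        return [[0.0] * 8 for _ in texts]


async def probe(embedding_model, chat_model):
    # None-guards let callers check one model at a time, mirroring
    # is_strong_enough(None, embedding_model) / is_strong_enough(chat_model, None).
    if embedding_model:
        with trio.fail_after(3):  # raises trio.TooSlowError past 3 s
            # run_sync moves the blocking call off the event loop.
            await trio.to_thread.run_sync(lambda: embedding_model.encode(["ping"]))
    if chat_model:
        with trio.fail_after(30):
            res = await trio.to_thread.run_sync(lambda: chat_model.chat("ping"))
        if "**ERROR**" in res:
            raise Exception(res)


trio.run(probe, FakeEmbeddingModel(), None)
```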

### graphrag/general/index.py (+1, -4)



```diff
 from api import settings
 from api.utils import get_uuid
-from api.utils.api_utils import timeout, is_strong_enough
+from api.utils.api_utils import timeout
 from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
 from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
```

```diff
     embedding_model,
     callback,
 ):
-    # Pressure test for GraphRAG task
-    await is_strong_enough(chat_model, embedding_model)
-
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
```
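With this move, the pressure test no longer runs inside the GraphRAG indexing entry point; instead, `rag/svr/task_executor.py` (below) probes each model with `is_strong_enough` immediately after binding it, so a weak or misconfigured model fails fast before any chunks are built.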

### rag/svr/task_executor.py (+5, -0)

```diff
+@timeout(60*40, 1)
 async def build_chunks(task, progress_callback):
     if task["size"] > DOC_MAXIMUM_SIZE:
         set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
```

```diff
     try:
         # bind embedding model
         embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
+        await is_strong_enough(None, embedding_model)
         vts, _ = embedding_model.encode(["ok"])
         vector_size = len(vts[0])
     except Exception as e:
```

```diff
     if task.get("task_type", "") == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        await is_strong_enough(chat_model, None)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
```

```diff
     graphrag_conf = task["kb_parser_config"].get("graphrag", {})
     start_ts = timer()
     chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+    await is_strong_enough(chat_model, None)
     with_resolution = graphrag_conf.get("resolution", False)
     with_community = graphrag_conf.get("community", False)
     async with kg_limiter:
```
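The headline change is `@timeout(60*40, 1)` on `build_chunks`, which caps a whole chunk-building run at 40 minutes (assuming the first argument is seconds, as the `trio.fail_after` calls above suggest). The decorator itself is defined in `api/utils/api_utils.py` and is not part of this diff; the sketch below shows one way a trio-based decorator with this call shape could work, with the second argument interpreted, as an assumption, as a retry count:

```python
import functools
import trio


def timeout(seconds, attempts=1):
    """Cancel the wrapped async call after `seconds`, retrying up to `attempts` times."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max(1, attempts)):
                try:
                    # Cancel the whole call if it runs past the deadline.
                    with trio.fail_after(seconds):
                        return await func(*args, **kwargs)
                except trio.TooSlowError as exc:
                    last_exc = exc
            raise last_exc
        return wrapper
    return decorator


@timeout(2, 1)
async def slow_task():
    await trio.sleep(5)  # blows the 2 s budget


# trio.run(slow_task)  # raises trio.TooSlowError after ~2 s
```

Cancelling at this level turns a silently stuck model call into a visible task failure that the executor can report.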
