Ver código fonte

Perf: test llm before RAPTOR. (#8897)

### What problem does this PR solve?

Probes the chat and embedding LLMs (via a shared `is_strong_enough` helper moved into `api/utils/api_utils.py`) before starting a RAPTOR task, so misconfigured or unresponsive models fail fast instead of deep inside the pipeline; the duplicated pressure-test code in `graphrag/general/index.py` is replaced by a call to the shared helper.

### Type of change

- [x] Performance Improvement
tags/v0.20.0
Kevin Hu 3 meses atrás
pai
commit
ecdb1701df
Nenhuma conta vinculada ao e-mail do autor do commit

+ 15
- 0
api/utils/api_utils.py Ver arquivo

return wrapper return wrapper
return decorator return decorator



async def is_strong_enough(chat_model, embedding_model):
    """Pressure-test the chat and embedding models before starting a heavy task.

    Launches 12 concurrent probe rounds; each round sends one trivial request to
    the embedding model and one to the chat model. Intended to fail fast (e.g.
    before a RAPTOR or GraphRAG run) instead of deep inside the pipeline.

    :param chat_model: model object exposing ``chat(system, history, conf)``.
    :param embedding_model: model object exposing ``encode(list_of_texts)``.
    :raises Exception: if the chat model answers with an ``**ERROR**`` marker,
        or (via the ``timeout`` decorator) if a probe exceeds its time budget.
    """

    @timeout(30, 2)
    async def _is_strong_enough():
        nonlocal chat_model, embedding_model
        # Embedding probe: the vector is discarded — only success matters.
        _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
        # Chat probe with a trivial prompt; providers signal failure in-band
        # with an "**ERROR**" marker rather than raising, so check the text.
        res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
        if res.find("**ERROR**") >= 0:
            raise Exception(res)

    # Pressure test: run 12 probe rounds concurrently.
    async with trio.open_nursery() as nursery:
        for _ in range(12):
            # Fix: _is_strong_enough() takes no positional arguments (it
            # captures both models via `nonlocal`); passing them to
            # start_soon made trio call it with two unexpected args,
            # raising TypeError on every probe.
            nursery.start_soon(_is_strong_enough)

+ 2
- 0
deepdoc/parser/figure_parser.py Ver arquivo



from PIL import Image from PIL import Image


from api.utils.api_utils import timeout
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.prompts import vision_llm_figure_describe_prompt from rag.prompts import vision_llm_figure_describe_prompt


def __call__(self, **kwargs): def __call__(self, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None) callback = kwargs.get("callback", lambda prog, msg: None)


@timeout(30, 3)
def process(figure_idx, figure_binary): def process(figure_idx, figure_binary):
description_text = picture_vision_llm_chunk( description_text = picture_vision_llm_chunk(
binary=figure_binary, binary=figure_binary,

+ 2
- 11
graphrag/general/index.py Ver arquivo



from api import settings from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import timeout
from api.utils.api_utils import timeout, is_strong_enough
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import RedisDistributedLock




@timeout(30, 2)
async def _is_strong_enough(chat_model, embedding_model):
    """Send one trivial request to each model; raise if the chat model errors.

    The ``timeout`` decorator bounds how long a single probe may take.
    """
    probe = "Are you strong enough!?"
    # Embedding probe — the returned vector is ignored, only liveness matters.
    await trio.to_thread.run_sync(lambda: embedding_model.encode([probe]))
    # Chat probe — providers report failures in-band via an "**ERROR**" marker.
    answer = await trio.to_thread.run_sync(
        lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": probe}], {})
    )
    if "**ERROR**" in answer:
        raise Exception(answer)



async def run_graphrag( async def run_graphrag(
row: dict, row: dict,
callback, callback,
): ):
# Pressure test for GraphRAG task # Pressure test for GraphRAG task
async with trio.open_nursery() as nursery:
for _ in range(12):
nursery.start_soon(_is_strong_enough, chat_model, embedding_model)
await is_strong_enough(chat_model, embedding_model)


start = trio.current_time() start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]

+ 3
- 1
rag/svr/task_executor.py Ver arquivo

import threading import threading
import time import time


from api.utils.api_utils import timeout
from api.utils.api_utils import timeout, is_strong_enough
from api.utils.log_utils import init_root_logger, get_project_base_directory from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache


@timeout(3600) @timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
# Pressure test for GraphRAG task
await is_strong_enough(chat_mdl, embd_mdl)
chunks = [] chunks = []
vctr_nm = "q_%d_vec"%vector_size vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],

Carregando…
Cancelar
Salvar