Browse Source

Perf: set timeout of some steps in KG. (#8873)

### What problem does this PR solve?

### Type of change


- [x] Performance Improvement
tags/v0.20.0
Kevin Hu 3 months ago
parent
commit
fbd115773b
No account linked to committer's email address

+ 4
- 4
api/db/services/dialog_service.py View File

@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily

@@ -97,7 +97,7 @@ class DialogService(CommonService):


def chat_solo(dialog, messages, stream=True):
if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])

if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):

chat_start_ts = timer()

if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

+ 2
- 0
api/db/services/document_service.py View File

@@ -583,6 +583,8 @@ class DocumentService(CommonService):
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
cls.update_by_id(d["id"], info)

+ 9
- 0
api/db/services/llm_service.py View File

@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)

@staticmethod
def llm_id2llm_type(llm_id: str) ->str|None:
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]


class LLMBundle:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):

+ 1
- 1
api/settings.py View File

@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory
from graphrag import search as kg_search
from rag.nlp import search

LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

retrievaler = search.Dealer(docStoreConn)
from graphrag import search as kg_search
kg_retrievaler = kg_search.KGSearch(docStoreConn)

if int(os.environ.get("SANDBOX_ENABLED", "0")):

+ 1
- 3
api/utils/api_utils.py View File

@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
from uuid import uuid1

import trio

from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
return transformed_data


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
results = {}
tool_call_sessions = []
try:

+ 3
- 3
graphrag/general/index.py View File

@@ -124,7 +124,7 @@ async def run_graphrag(
return


@timeout(60*60*2)
@timeout(60*60, 1)
async def generate_subgraph(
extractor: Extractor,
tenant_id: str,
@@ -229,7 +229,7 @@ async def merge_subgraph(
return new_graph


@timeout(60*60)
@timeout(60*30, 1)
async def resolve_entities(
graph,
subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30)
@timeout(60*30, 1)
async def extract_community(
graph,
tenant_id: str,

+ 3
- 2
graphrag/utils.py View File

@@ -17,13 +17,12 @@ from typing import Any, Callable
import os
import trio
from typing import Set, Tuple

import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph
import dataclasses
from api.utils.api_utils import timeout
from api import settings
from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


@timeout(1, 3)
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
return res


@timeout(1, 3)
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = {
"id": get_uuid(),

+ 4
- 15
rag/prompts.py View File

@@ -22,7 +22,6 @@ from collections import defaultdict
import jinja2
import json_repair

from api import settings
from rag.prompt_template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
]


def llm_id2llm_type(llm_id):
from api.db.services.llm_service import TenantLLMService

llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)

llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]


def message_fit_in(msg, max_length=4000):
def count():
nonlocal msg
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
def full_question(tenant_id, llm_id, messages, language=None):
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService

if llm_id2llm_type(llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def full_question(tenant_id, llm_id, messages, language=None):
def cross_languages(tenant_id, llm_id, query, languages=[]):
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService

if llm_id and llm_id2llm_type(llm_id) == "image2text":
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

+ 1
- 1
rag/svr/task_executor.py View File

@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count


@timeout(60*60*1.5)
@timeout(60*60, 1)
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]

Loading…
Cancel
Save