浏览代码

Perf: set timeout of some steps in KG. (#8873)

### What problem does this PR solve?

### Type of change


- [x] Performance Improvement
tags/v0.20.0
Kevin Hu 3 个月前
父节点
当前提交
fbd115773b
没有帐户链接到提交者的电子邮件

+ 4
- 4
api/db/services/dialog_service.py 查看文件

from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily






def chat_solo(dialog, messages, stream=True): def chat_solo(dialog, messages, stream=True):
if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else: else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
if not embd_mdl: if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0]) raise LookupError("Embedding model(%s) not found" % embedding_list[0])


if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else: else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)


chat_start_ts = timer() chat_start_ts = timer()


if llm_id2llm_type(dialog.llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else: else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

+ 2
- 0
api/db/services/document_service.py 查看文件

info["progress"] = prg info["progress"] = prg
if msg: if msg:
info["progress_msg"] = msg info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else: else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
cls.update_by_id(d["id"], info) cls.update_by_id(d["id"], info)

+ 9
- 0
api/db/services/llm_service.py 查看文件

objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs) return list(objs)


@staticmethod
def llm_id2llm_type(llm_id: str) ->str|None:
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]



class LLMBundle: class LLMBundle:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"): def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):

+ 1
- 1
api/settings.py 查看文件

from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config from api.utils import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from graphrag import search as kg_search
from rag.nlp import search from rag.nlp import search


LIGHTEN = int(os.environ.get("LIGHTEN", "0")) LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
raise Exception(f"Not supported doc engine: {DOC_ENGINE}") raise Exception(f"Not supported doc engine: {DOC_ENGINE}")


retrievaler = search.Dealer(docStoreConn) retrievaler = search.Dealer(docStoreConn)
from graphrag import search as kg_search
kg_retrievaler = kg_search.KGSearch(docStoreConn) kg_retrievaler = kg_search.KGSearch(docStoreConn)


if int(os.environ.get("SANDBOX_ENABLED", "0")): if int(os.environ.get("SANDBOX_ENABLED", "0")):

+ 1
- 3
api/utils/api_utils.py 查看文件

from uuid import uuid1 from uuid import uuid1


import trio import trio

from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions




return transformed_data return transformed_data




def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
results = {} results = {}
tool_call_sessions = [] tool_call_sessions = []
try: try:

+ 3
- 3
graphrag/general/index.py 查看文件

return return




@timeout(60*60*2)
@timeout(60*60, 1)
async def generate_subgraph( async def generate_subgraph(
extractor: Extractor, extractor: Extractor,
tenant_id: str, tenant_id: str,
return new_graph return new_graph




@timeout(60*60)
@timeout(60*30, 1)
async def resolve_entities( async def resolve_entities(
graph, graph,
subgraph_nodes: set[str], subgraph_nodes: set[str],
callback(msg=f"Graph resolution done in {now - start:.2f}s.") callback(msg=f"Graph resolution done in {now - start:.2f}s.")




@timeout(60*30)
@timeout(60*30, 1)
async def extract_community( async def extract_community(
graph, graph,
tenant_id: str, tenant_id: str,

+ 3
- 2
graphrag/utils.py 查看文件

import os import os
import trio import trio
from typing import Set, Tuple from typing import Set, Tuple

import networkx as nx import networkx as nx
import numpy as np import numpy as np
import xxhash import xxhash
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
import dataclasses import dataclasses
from api.utils.api_utils import timeout
from api import settings from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest() return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()




@timeout(1, 3)
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
chunk = { chunk = {
"id": get_uuid(), "id": get_uuid(),
return res return res




@timeout(1, 3)
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks): async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = { chunk = {
"id": get_uuid(), "id": get_uuid(),

+ 4
- 15
rag/prompts.py 查看文件

import jinja2 import jinja2
import json_repair import json_repair


from api import settings
from rag.prompt_template import load_prompt from rag.prompt_template import load_prompt
from rag.settings import TAG_FLD from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string from rag.utils import encoder, num_tokens_from_string
] ]




def llm_id2llm_type(llm_id):
from api.db.services.llm_service import TenantLLMService

llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)

llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]


def message_fit_in(msg, max_length=4000): def message_fit_in(msg, max_length=4000):
def count(): def count():
nonlocal msg nonlocal msg
def full_question(tenant_id, llm_id, messages, language=None): def full_question(tenant_id, llm_id, messages, language=None):
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService


if llm_id2llm_type(llm_id) == "image2text":
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else: else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
def cross_languages(tenant_id, llm_id, query, languages=[]): def cross_languages(tenant_id, llm_id, query, languages=[]):
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService


if llm_id and llm_id2llm_type(llm_id) == "image2text":
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else: else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

+ 1
- 1
rag/svr/task_executor.py 查看文件

return res, tk_count return res, tk_count




@timeout(60*60*1.5)
@timeout(60*60, 1)
async def do_handle_task(task): async def do_handle_task(task):
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]

正在加载...
取消
保存