### What problem does this PR solve?

Startup- and runtime-performance cleanups: `llm_id2llm_type` moves out of `rag/prompts.py` into `TenantLLMService` as a static method; the heavyweight `graphrag` and MCP imports are deferred until they are actually needed; the GraphRAG, RAPTOR, and task-handler timeouts are tightened (with an extra argument added to `@timeout`); and document-progress messages now report how many tasks are ahead in the queue.

### Type of change

- [x] Performance Improvement
```diff
@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
@@ -97,7 +97,7 @@ class DialogService(CommonService):
 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     chat_start_ts = timer()

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
```
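All three call sites in `dialog_service` swap the free function `llm_id2llm_type` (formerly in `rag/prompts.py`) for the new `TenantLLMService.llm_id2llm_type` static method; the dispatch logic itself is unchanged. A minimal sketch of that shared pattern, treating `LLMBundle`, `LLMType`, and `TenantLLMService` as the project APIs they are:

```python
# Sketch of the dispatch pattern used at all three call sites above.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService

def pick_chat_model(tenant_id: str, llm_id: str) -> LLMBundle:
    # Vision-capable models need the IMAGE2TEXT bundle; everything else
    # falls back to a plain CHAT bundle.
    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        return LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    return LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```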
```diff
@@ -583,6 +583,8 @@ class DocumentService(CommonService):
             info["progress"] = prg
         if msg:
             info["progress_msg"] = msg
+            if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
+        else:
+            info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
         cls.update_by_id(d["id"], info)
```
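The progress updater now surfaces queue pressure: when a message announces a newly created `graphrag` or `raptor` task, the queue depth is appended to it, and when there is no message at all, the queue depth becomes the whole message. The branch in isolation, with the queue length passed in so the sketch is self-contained (`get_queue_length(priority)` is the project helper that supplies it):

```python
# Self-contained sketch of the progress-message branch added above.
def build_progress_msg(msg: str, queue_len: int) -> str:
    queue_info = "%d tasks are ahead in the queue..." % queue_len
    if not msg:
        return queue_info  # no message yet: report only the queue depth
    if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
        return msg + "\n" + queue_info  # append depth to the creation notice
    return msg
```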
```diff
@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
+
+    @staticmethod
+    def llm_id2llm_type(llm_id: str) -> str | None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]

 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
```
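The relocated helper maps a model ID to the `model_type` declared in `settings.FACTORY_LLM_INFOS`, returning `None` when no factory lists the model. Hypothetical usage (the model-ID format is assumed, not taken from this diff):

```python
from api.db.services.llm_service import TenantLLMService

# Illustrative IDs; real values depend on the configured factories.
TenantLLMService.llm_id2llm_type("gpt-4o@OpenAI")   # the model's declared type
TenantLLMService.llm_id2llm_type("no-such-model")   # None: nothing matched
```

One detail worth double-checking in review: `strip(",")[-1]` yields the last *character* of `model_type`, so for the `== "image2text"` comparisons at the call sites to ever match, `split(",")[-1]` (the last comma-separated entry) looks like the intended expression. The same expression appears in the removed `rag/prompts.py` version, so the move at least preserves existing behavior.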
```diff
@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search

 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)

     if int(os.environ.get("SANDBOX_ENABLED", "0")):
```
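Dropping the module-level `graphrag` import and re-importing it inside `init_settings()` means the heavy graphrag dependency chain is paid for only when settings are actually initialized, not whenever `api.settings` is imported. A rough, illustrative way to see what a deferred import costs (assumes the package is importable in your environment):

```python
# Rough measurement of a module's one-time import cost (illustrative).
import importlib
import time

t0 = time.perf_counter()
importlib.import_module("graphrag.search")  # the dependency being deferred
print(f"import took {time.perf_counter() - t0:.2f}s")
```

The interpreter's built-in `python -X importtime` flag gives the same information as a per-module breakdown for a whole process.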
```diff
@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
 from uuid import uuid1

 import trio

-from api.db.db_models import MCPServer
-from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
     return transformed_data

-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
```
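Loosening the annotation from `list[MCPServer]` to plain `list` is what allows the two module-level MCP imports above to be removed. If the precise element type were still wanted without the import-time cost, a `TYPE_CHECKING` guard is the usual alternative; a sketch of that option (not what this PR does):

```python
from __future__ import annotations  # annotations stay unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported when the module loads.
    from api.db.db_models import MCPServer

def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    ...
```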
```diff
@@ -124,7 +124,7 @@ async def run_graphrag(
         return

-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
@@ -229,7 +229,7 @@ async def merge_subgraph(
     return new_graph

-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")

-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
```
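Each decorator here both tightens the time budget (`generate_subgraph` drops from two hours to one, `resolve_entities` from an hour to thirty minutes) and gains a second positional argument that this diff does not define. Assuming that argument is an attempt count, a decorator of that shape could look like the sketch below; it is not the project's implementation (which, given the `trio` imports elsewhere in this PR, likely uses trio rather than asyncio):

```python
import asyncio
import functools

def timeout(seconds: float, attempts: int = 1):
    """Sketch of a (seconds, attempts) decorator matching calls like
    @timeout(60*60, 1); the attempt-count reading is an assumption."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max(1, attempts)):
                try:
                    return await asyncio.wait_for(func(*args, **kwargs), seconds)
                except asyncio.TimeoutError as exc:
                    last_exc = exc  # retry until attempts are exhausted
            raise last_exc
        return wrapper
    return decorator
```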
```diff
@@ -17,13 +17,12 @@ from typing import Any, Callable
 import os
 import trio
 from typing import Set, Tuple
 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
 import dataclasses
+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()

+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res

+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
```
```diff
@@ -22,7 +22,6 @@ from collections import defaultdict
 import jinja2
 import json_repair

-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
     ]

-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]
-
 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def full_question(tenant_id, llm_id, messages, language=None):
 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```
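With the helper gone, `rag/prompts.py` no longer needs `from api import settings` at all, which is why the first hunk removes it. Both callers keep the new `TenantLLMService` import inside the function body, alongside the existing deferred `LLMBundle` import, the usual way to avoid a circular import between the prompts module and the services layer.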
```diff
@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count

-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
```
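`do_handle_task` gets the same treatment as the GraphRAG coroutines: the flat 90-minute budget (`60*60*1.5` = 5400 s) becomes a 60-minute budget (`3600` s) with the additional positional argument seen throughout this PR.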