### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
```diff
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily

 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     chat_start_ts = timer()

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
```
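Throughout this hunk the PR replaces the free function `llm_id2llm_type` (formerly in `rag.prompts`) with the new `TenantLLMService.llm_id2llm_type` static method. The same image2text-vs-chat dispatch appears three times; as a minimal sketch (not part of the PR, and the helper name is hypothetical), it could be collapsed into one place:

```python
# Hypothetical helper (not in the PR): one home for the repeated
# image2text-vs-chat dispatch shown in the hunk above.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService


def pick_chat_model(tenant_id: str, llm_id: str) -> LLMBundle:
    """Return an LLMBundle matching the model's declared type."""
    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        return LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    return LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```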
| info["progress"] = prg | info["progress"] = prg | ||||
| if msg: | if msg: | ||||
| info["progress_msg"] = msg | info["progress_msg"] = msg | ||||
| if msg.endswith("created task graphrag") or msg.endswith("created task raptor"): | |||||
| info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) | |||||
| else: | else: | ||||
| info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) | info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) | ||||
| cls.update_by_id(d["id"], info) | cls.update_by_id(d["id"], info) |
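The new branch appends the current queue depth whenever the progress message reports a freshly queued graphrag or raptor task, instead of only showing the depth when no message is present. A standalone illustration, with the queue length passed in directly (the real code calls `get_queue_length(priority)`):

```python
# Standalone illustration of the new progress-message logic;
# queue_len stands in for get_queue_length(priority).
def progress_msg(msg: str, queue_len: int) -> str:
    if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
        return msg + "\n%d tasks are ahead in the queue..." % queue_len
    return msg


print(progress_msg("created task graphrag", 4))
# -> created task graphrag
#    4 tasks are ahead in the queue...
```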
```diff
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)

+    @staticmethod
+    def llm_id2llm_type(llm_id: str) -> str | None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]

 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
```
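This is the relocated helper itself, now a static method on `TenantLLMService`. It first strips any factory suffix from the model id via `split_model_name_and_factory`, then scans `settings.FACTORY_LLM_INFOS` for a matching `llm_name`. A hedged usage sketch; the concrete model id below is only an example:

```python
from api.db.services.llm_service import TenantLLMService

# "model@factory" ids are split first, so both spellings resolve the
# same way; the id here is illustrative, not required by the API.
if TenantLLMService.llm_id2llm_type("qwen-vl-max@Tongyi-Qianwen") == "image2text":
    print("vision-capable chat model")
```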
```diff
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search

 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))

         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)

 if int(os.environ.get("SANDBOX_ENABLED", "0")):
```
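Here the `graphrag` import moves from module top to the point of use. A plausible motive (an assumption; the PR description doesn't say) is to keep `graphrag` out of the import graph until a document-store connection is actually being wired up, which also sidesteps import-order issues between this settings module and the `graphrag` package. The pattern in isolation:

```python
# Deferred-import pattern from the hunk above; the function name is
# hypothetical, the module names are the ones visible in the diff.
from rag.nlp import search


def build_retrievers(docStoreConn):
    retrievaler = search.Dealer(docStoreConn)
    # Imported only when retrievers are actually built, so merely
    # importing this module never pulls in graphrag.
    from graphrag import search as kg_search
    kg_retrievaler = kg_search.KGSearch(docStoreConn)
    return retrievaler, kg_retrievaler
```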
```diff
 from uuid import uuid1

 import trio

-from api.db.db_models import MCPServer
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

     return transformed_data

-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
```
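Loosening the annotation to plain `list` lets the module drop its runtime dependency on `api.db.db_models`. An alternative that keeps the precise hint without the runtime import is a `TYPE_CHECKING` guard; this is a sketch of that option, not what the PR chose:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
    from api.db.db_models import MCPServer


def get_mcp_tools(mcp_servers: "list[MCPServer]", timeout: float | int = 10) -> tuple[dict, str]:
    ...
```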
```diff
     return

-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,

     return new_graph

-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],

     callback(msg=f"Graph resolution done in {now - start:.2f}s.")

-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
```
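Each `@timeout(seconds)` becomes `@timeout(seconds, n)` with shorter deadlines. A plausible reading, and it is an assumption since the decorator's source is not part of this diff, is `timeout(seconds, attempts)`: each attempt is bounded by `seconds` and retried up to `attempts` times. A minimal trio-based sketch under that assumption:

```python
# Sketch of an assumed timeout(seconds, attempts) decorator; the real
# one lives in api.utils.api_utils and may differ.
import functools

import trio


def timeout(seconds: float, attempts: int = 1):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max(1, attempts)):
                try:
                    # bound this single attempt; expiry raises TooSlowError
                    with trio.fail_after(seconds):
                        return await func(*args, **kwargs)
                except trio.TooSlowError as exc:
                    last_exc = exc
            raise last_exc
        return wrapper
    return decorator
```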
```diff
 import os
 import trio
 from typing import Set, Tuple
 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
 import dataclasses
+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer

     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()

+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),

     return res

+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
```
```diff
 import jinja2
 import json_repair

-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string

 ]

-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]

 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg

 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```
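With the duplicate helper deleted from `rag.prompts`, `TenantLLMService` is imported inside `full_question` and `cross_languages` rather than at module top. A likely reason (my reading, not stated in the PR) is that the service layer itself imports `rag.prompts`, so a top-level import back into it would form a cycle; the function-local import defers that edge until call time, after both modules have finished loading:

```python
# Hypothetical minimal illustration of the cycle being avoided:
#
#   api/db/services/llm_service.py   imports   rag.prompts
#   rag/prompts.py                   needs     TenantLLMService
#
# A top-level import in rag/prompts.py would re-enter llm_service
# before it finished initializing; deferring it avoids that.
def question_model_type(tenant_id, llm_id):
    from api.db.services.llm_service import TenantLLMService  # deferred
    return TenantLLMService.llm_id2llm_type(llm_id)
```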
```diff
     return res, tk_count

-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
```