### What problem does this PR solve? ### Type of change - [x] Refactoringtags/v0.20.2
| import json_repair | import json_repair | ||||
| from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta | from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta | ||||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||||
| from api.db.services.llm_service import LLMBundle | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.db.services.mcp_server_service import MCPServerService | from api.db.services.mcp_server_service import MCPServerService | ||||
| from api.utils.api_utils import timeout | from api.utils.api_utils import timeout | ||||
| from rag.prompts import message_fit_in | from rag.prompts import message_fit_in |
| from functools import partial | from functools import partial | ||||
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||||
| from api.db.services.llm_service import LLMBundle | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from agent.component.base import ComponentBase, ComponentParamBase | from agent.component.base import ComponentBase, ComponentParamBase | ||||
| from api.utils.api_utils import timeout | from api.utils.api_utils import timeout | ||||
| from rag.prompts import message_fit_in, citation_prompt | from rag.prompts import message_fit_in, citation_prompt |
| from api.db.services.conversation_service import ConversationService, structure_answer | from api.db.services.conversation_service import ConversationService, structure_answer | ||||
| from api.db.services.dialog_service import DialogService, ask, chat | from api.db.services.dialog_service import DialogService, ask, chat | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMBundle, TenantService | |||||
| from api.db.services.user_service import UserTenantService | |||||
| from api.db.services.llm_service import LLMBundle | |||||
| from api.db.services.user_service import UserTenantService, TenantService | |||||
| from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request | from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request | ||||
| from graphrag.general.mind_map_extractor import MindMapExtractor | from graphrag.general.mind_map_extractor import MindMapExtractor | ||||
| from rag.app.tag import label_question | from rag.app.tag import label_question |
| from flask_login import login_required, current_user | from flask_login import login_required, current_user | ||||
| from api.db.services.dialog_service import DialogService | from api.db.services.dialog_service import DialogService | ||||
| from api.db import StatusEnum | from api.db import StatusEnum | ||||
| from api.db.services.llm_service import TenantLLMService | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.user_service import TenantService, UserTenantService | from api.db.services.user_service import TenantService, UserTenantService | ||||
| from api import settings | from api import settings |
| import json | import json | ||||
| from flask import request | from flask import request | ||||
| from flask_login import login_required, current_user | from flask_login import login_required, current_user | ||||
| from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||||
| from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService | |||||
| from api.db.services.llm_service import LLMService | |||||
| from api import settings | from api import settings | ||||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | ||||
| from api.db import StatusEnum, LLMType | from api.db import StatusEnum, LLMType |
| from api.db import StatusEnum | from api.db import StatusEnum | ||||
| from api.db.services.dialog_service import DialogService | from api.db.services.dialog_service import DialogService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import TenantLLMService | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.db.services.user_service import TenantService | from api.db.services.user_service import TenantService | ||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required | from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required |
| from api.db.services.file2document_service import File2DocumentService | from api.db.services.file2document_service import File2DocumentService | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||||
| from api.db.services.llm_service import LLMBundle | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.db.services.task_service import TaskService, queue_tasks | from api.db.services.task_service import TaskService, queue_tasks | ||||
| from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required | from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required | ||||
| from rag.app.qa import beAdoc, rmPrefix | from rag.app.qa import beAdoc, rmPrefix |
| import json | import json | ||||
| import re | import re | ||||
| import time | import time | ||||
| import tiktoken | import tiktoken | ||||
| from flask import Response, jsonify, request | from flask import Response, jsonify, request | ||||
| from agent.canvas import Canvas | from agent.canvas import Canvas | ||||
| from api.db import LLMType, StatusEnum | from api.db import LLMType, StatusEnum | ||||
| from api.db.db_models import API4Conversation, APIToken | |||||
| from api.db.db_models import APIToken | |||||
| from api.db.services.api_service import API4ConversationService | from api.db.services.api_service import API4ConversationService | ||||
| from api.db.services.canvas_service import UserCanvasService, completionOpenAI | from api.db.services.canvas_service import UserCanvasService, completionOpenAI | ||||
| from api.db.services.canvas_service import completion as agent_completion | from api.db.services.canvas_service import completion as agent_completion | ||||
| from api.db.services.conversation_service import ConversationService, iframe_completion | from api.db.services.conversation_service import ConversationService, iframe_completion | ||||
| from api.db.services.conversation_service import completion as rag_completion | from api.db.services.conversation_service import completion as rag_completion | ||||
| from api.db.services.dialog_service import DialogService, ask, chat | from api.db.services.dialog_service import DialogService, ask, chat | ||||
| from api.db.services.file_service import FileService | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.utils import get_uuid | from api.utils import get_uuid |
| from api.db import FileType, UserTenantRole | from api.db import FileType, UserTenantRole | ||||
| from api.db.db_models import TenantLLM | from api.db.db_models import TenantLLM | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||||
| from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm | |||||
| from api.db.services.user_service import TenantService, UserService, UserTenantService | from api.db.services.user_service import TenantService, UserService, UserTenantService | ||||
| from api.utils import ( | from api.utils import ( | ||||
| current_timestamp, | current_timestamp, | ||||
| "size": 0, | "size": 0, | ||||
| "location": "", | "location": "", | ||||
| } | } | ||||
| tenant_llm = [] | |||||
| seen = set() | |||||
| factory_configs = [] | |||||
| for factory_config in [ | |||||
| settings.CHAT_CFG, | |||||
| settings.EMBEDDING_CFG, | |||||
| settings.ASR_CFG, | |||||
| settings.IMAGE2TEXT_CFG, | |||||
| settings.RERANK_CFG, | |||||
| ]: | |||||
| factory_name = factory_config["factory"] | |||||
| if factory_name not in seen: | |||||
| seen.add(factory_name) | |||||
| factory_configs.append(factory_config) | |||||
| for factory_config in factory_configs: | |||||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||||
| tenant_llm.append( | |||||
| { | |||||
| "tenant_id": user_id, | |||||
| "llm_factory": factory_config["factory"], | |||||
| "llm_name": llm.llm_name, | |||||
| "model_type": llm.model_type, | |||||
| "api_key": factory_config["api_key"], | |||||
| "api_base": factory_config["base_url"], | |||||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||||
| } | |||||
| ) | |||||
| if settings.LIGHTEN != 1: | |||||
| for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) | |||||
| tenant_llm.append( | |||||
| { | |||||
| "tenant_id": user_id, | |||||
| "llm_factory": fid, | |||||
| "llm_name": mdlnm, | |||||
| "model_type": "embedding", | |||||
| "api_key": "", | |||||
| "api_base": "", | |||||
| "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, | |||||
| } | |||||
| ) | |||||
| unique = {} | |||||
| for item in tenant_llm: | |||||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||||
| if key not in unique: | |||||
| unique[key] = item | |||||
| tenant_llm = list(unique.values()) | |||||
| tenant_llm = get_init_tenant_llm(user_id) | |||||
| if not UserService.save(**user): | if not UserService.save(**user): | ||||
| return | return |
| from api.db.services.canvas_service import CanvasTemplateService | from api.db.services.canvas_service import CanvasTemplateService | ||||
| from api.db.services.document_service import DocumentService | from api.db.services.document_service import DocumentService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | |||||
| from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService | |||||
| from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm | |||||
| from api.db.services.user_service import TenantService, UserTenantService | from api.db.services.user_service import TenantService, UserTenantService | ||||
| from api import settings | from api import settings | ||||
| from api.utils.file_utils import get_project_base_directory | from api.utils.file_utils import get_project_base_directory | ||||
| "role": UserTenantRole.OWNER | "role": UserTenantRole.OWNER | ||||
| } | } | ||||
| user_id = user_info | |||||
| tenant_llm = [] | |||||
| seen = set() | |||||
| factory_configs = [] | |||||
| for factory_config in [ | |||||
| settings.CHAT_CFG["factory"], | |||||
| settings.EMBEDDING_CFG["factory"], | |||||
| settings.ASR_CFG["factory"], | |||||
| settings.IMAGE2TEXT_CFG["factory"], | |||||
| settings.RERANK_CFG["factory"], | |||||
| ]: | |||||
| factory_name = factory_config["factory"] | |||||
| if factory_name not in seen: | |||||
| seen.add(factory_name) | |||||
| factory_configs.append(factory_config) | |||||
| for factory_config in factory_configs: | |||||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||||
| tenant_llm.append( | |||||
| { | |||||
| "tenant_id": user_id, | |||||
| "llm_factory": factory_config["factory"], | |||||
| "llm_name": llm.llm_name, | |||||
| "model_type": llm.model_type, | |||||
| "api_key": factory_config["api_key"], | |||||
| "api_base": factory_config["base_url"], | |||||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||||
| } | |||||
| ) | |||||
| unique = {} | |||||
| for item in tenant_llm: | |||||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||||
| if key not in unique: | |||||
| unique[key] = item | |||||
| tenant_llm = list(unique.values()) | |||||
| tenant_llm = get_init_tenant_llm(user_info["id"]) | |||||
| if not UserService.save(**user_info): | if not UserService.save(**user_info): | ||||
| logging.error("can't init admin.") | logging.error("can't init admin.") |
| from api.db.services.document_service import DocumentService | from api.db.services.document_service import DocumentService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.langfuse_service import TenantLangfuseService | from api.db.services.langfuse_service import TenantLangfuseService | ||||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||||
| from api.db.services.llm_service import LLMBundle | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.utils import current_timestamp, datetime_format | from api.utils import current_timestamp, datetime_format | ||||
| from rag.app.resume import forbidden_select_fields4resume | from rag.app.resume import forbidden_select_fields4resume | ||||
| from rag.app.tag import label_question | from rag.app.tag import label_question |
| import re | import re | ||||
| from functools import partial | from functools import partial | ||||
| from typing import Generator | from typing import Generator | ||||
| from langfuse import Langfuse | |||||
| from api import settings | |||||
| from api.db import LLMType | |||||
| from api.db.db_models import DB, LLM, LLMFactories, TenantLLM | |||||
| from api.db.db_models import LLM | |||||
| from api.db.services.common_service import CommonService | from api.db.services.common_service import CommonService | ||||
| from api.db.services.langfuse_service import TenantLangfuseService | |||||
| from api.db.services.user_service import TenantService | |||||
| from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel | |||||
| class LLMFactoriesService(CommonService): | |||||
| model = LLMFactories | |||||
| from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService | |||||
| class LLMService(CommonService): | class LLMService(CommonService): | ||||
| model = LLM | model = LLM | ||||
| class TenantLLMService(CommonService): | |||||
| model = TenantLLM | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_api_key(cls, tenant_id, model_name): | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) | |||||
| if not fid: | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) | |||||
| else: | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||||
| if (not objs) and fid: | |||||
| if fid == "LocalAI": | |||||
| mdlnm += "___LocalAI" | |||||
| elif fid == "HuggingFace": | |||||
| mdlnm += "___HuggingFace" | |||||
| elif fid == "OpenAI-API-Compatible": | |||||
| mdlnm += "___OpenAI-API" | |||||
| elif fid == "VLLM": | |||||
| mdlnm += "___VLLM" | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||||
| if not objs: | |||||
| return | |||||
| return objs[0] | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_my_llms(cls, tenant_id): | |||||
| fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] | |||||
| objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() | |||||
| return list(objs) | |||||
| @staticmethod | |||||
| def split_model_name_and_factory(model_name): | |||||
| arr = model_name.split("@") | |||||
| if len(arr) < 2: | |||||
| return model_name, None | |||||
| if len(arr) > 2: | |||||
| return "@".join(arr[0:-1]), arr[-1] | |||||
| # model name must be xxx@yyy | |||||
| try: | |||||
| model_factories = settings.FACTORY_LLM_INFOS | |||||
| model_providers = set([f["name"] for f in model_factories]) | |||||
| if arr[-1] not in model_providers: | |||||
| return model_name, None | |||||
| return arr[0], arr[-1] | |||||
| except Exception as e: | |||||
| logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") | |||||
| return model_name, None | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_model_config(cls, tenant_id, llm_type, llm_name=None): | |||||
| e, tenant = TenantService.get_by_id(tenant_id) | |||||
| if not e: | |||||
| raise LookupError("Tenant not found") | |||||
| if llm_type == LLMType.EMBEDDING.value: | |||||
| mdlnm = tenant.embd_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.SPEECH2TEXT.value: | |||||
| mdlnm = tenant.asr_id | |||||
| elif llm_type == LLMType.IMAGE2TEXT.value: | |||||
| mdlnm = tenant.img2txt_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.CHAT.value: | |||||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.RERANK: | |||||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.TTS: | |||||
| mdlnm = tenant.tts_id if not llm_name else llm_name | |||||
| else: | |||||
| assert False, "LLM type error" | |||||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) | |||||
| if not model_config: # for some cases seems fid mismatch | |||||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||||
| if model_config: | |||||
| model_config = model_config.to_dict() | |||||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||||
| if not llm and fid: # for some cases seems fid mismatch | |||||
| llm = LLMService.query(llm_name=mdlnm) | |||||
| if llm: | |||||
| model_config["is_tools"] = llm[0].is_tools | |||||
| if not model_config: | |||||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: | |||||
| model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} | |||||
| if not model_config: | |||||
| if mdlnm == "flag-embedding": | |||||
| model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} | |||||
| else: | |||||
| if not mdlnm: | |||||
| raise LookupError(f"Type of {llm_type} model is not set.") | |||||
| raise LookupError("Model({}) not authorized".format(mdlnm)) | |||||
| return model_config | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||||
| kwargs.update({"provider": model_config["llm_factory"]}) | |||||
| if llm_type == LLMType.EMBEDDING.value: | |||||
| if model_config["llm_factory"] not in EmbeddingModel: | |||||
| return | |||||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.RERANK: | |||||
| if model_config["llm_factory"] not in RerankModel: | |||||
| return | |||||
| return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||||
| if model_config["llm_factory"] not in CvModel: | |||||
| return | |||||
| return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) | |||||
| if llm_type == LLMType.CHAT.value: | |||||
| if model_config["llm_factory"] not in ChatModel: | |||||
| return | |||||
| return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) | |||||
| if llm_type == LLMType.SPEECH2TEXT: | |||||
| if model_config["llm_factory"] not in Seq2txtModel: | |||||
| return | |||||
| return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.TTS: | |||||
| if model_config["llm_factory"] not in TTSModel: | |||||
| return | |||||
| return TTSModel[model_config["llm_factory"]]( | |||||
| model_config["api_key"], | |||||
| model_config["llm_name"], | |||||
| base_url=model_config["api_base"], | |||||
| def get_init_tenant_llm(user_id): | |||||
| from api import settings | |||||
| tenant_llm = [] | |||||
| seen = set() | |||||
| factory_configs = [] | |||||
| for factory_config in [ | |||||
| settings.CHAT_CFG, | |||||
| settings.EMBEDDING_CFG, | |||||
| settings.ASR_CFG, | |||||
| settings.IMAGE2TEXT_CFG, | |||||
| settings.RERANK_CFG, | |||||
| ]: | |||||
| factory_name = factory_config["factory"] | |||||
| if factory_name not in seen: | |||||
| seen.add(factory_name) | |||||
| factory_configs.append(factory_config) | |||||
| for factory_config in factory_configs: | |||||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||||
| tenant_llm.append( | |||||
| { | |||||
| "tenant_id": user_id, | |||||
| "llm_factory": factory_config["factory"], | |||||
| "llm_name": llm.llm_name, | |||||
| "model_type": llm.model_type, | |||||
| "api_key": factory_config["api_key"], | |||||
| "api_base": factory_config["base_url"], | |||||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||||
| } | |||||
| ) | ) | ||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): | |||||
| e, tenant = TenantService.get_by_id(tenant_id) | |||||
| if not e: | |||||
| logging.error(f"Tenant not found: {tenant_id}") | |||||
| return 0 | |||||
| llm_map = { | |||||
| LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, | |||||
| LLMType.SPEECH2TEXT.value: tenant.asr_id, | |||||
| LLMType.IMAGE2TEXT.value: tenant.img2txt_id, | |||||
| LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, | |||||
| LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, | |||||
| LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, | |||||
| } | |||||
| mdlnm = llm_map.get(llm_type) | |||||
| if mdlnm is None: | |||||
| logging.error(f"LLM type error: {llm_type}") | |||||
| return 0 | |||||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) | |||||
| try: | |||||
| num = ( | |||||
| cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) | |||||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) | |||||
| .execute() | |||||
| if settings.LIGHTEN != 1: | |||||
| for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) | |||||
| tenant_llm.append( | |||||
| { | |||||
| "tenant_id": user_id, | |||||
| "llm_factory": fid, | |||||
| "llm_name": mdlnm, | |||||
| "model_type": "embedding", | |||||
| "api_key": "", | |||||
| "api_base": "", | |||||
| "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, | |||||
| } | |||||
| ) | ) | ||||
| except Exception: | |||||
| logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) | |||||
| return 0 | |||||
| return num | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_openai_models(cls): | |||||
| objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() | |||||
| return list(objs) | |||||
| @staticmethod | |||||
| def llm_id2llm_type(llm_id: str) -> str | None: | |||||
| llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) | |||||
| llm_factories = settings.FACTORY_LLM_INFOS | |||||
| for llm_factory in llm_factories: | |||||
| for llm in llm_factory["llm"]: | |||||
| if llm_id == llm["llm_name"]: | |||||
| return llm["model_type"].split(",")[-1] | |||||
| for llm in LLMService.query(llm_name=llm_id): | |||||
| return llm.model_type | |||||
| llm = TenantLLMService.get_or_none(llm_name=llm_id) | |||||
| if llm: | |||||
| return llm.model_type | |||||
| for llm in TenantLLMService.query(llm_name=llm_id): | |||||
| return llm.model_type | |||||
| unique = {} | |||||
| for item in tenant_llm: | |||||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||||
| if key not in unique: | |||||
| unique[key] = item | |||||
| return list(unique.values()) | |||||
| class LLMBundle: | |||||
| class LLMBundle(LLM4Tenant): | |||||
| def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | ||||
| self.tenant_id = tenant_id | |||||
| self.llm_type = llm_type | |||||
| self.llm_name = llm_name | |||||
| self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) | |||||
| assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) | |||||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||||
| self.max_length = model_config.get("max_tokens", 8192) | |||||
| self.is_tools = model_config.get("is_tools", False) | |||||
| self.verbose_tool_use = kwargs.get("verbose_tool_use") | |||||
| langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) | |||||
| self.langfuse = None | |||||
| if langfuse_keys: | |||||
| langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) | |||||
| if langfuse.auth_check(): | |||||
| self.langfuse = langfuse | |||||
| trace_id = self.langfuse.create_trace_id() | |||||
| self.trace_context = {"trace_id": trace_id} | |||||
| super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs) | |||||
| def bind_tools(self, toolcall_session, tools): | def bind_tools(self, toolcall_session, tools): | ||||
| if not self.is_tools: | if not self.is_tools: |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import logging | |||||
| from langfuse import Langfuse | |||||
| from api import settings | |||||
| from api.db import LLMType | |||||
| from api.db.db_models import DB, LLMFactories, TenantLLM | |||||
| from api.db.services.common_service import CommonService | |||||
| from api.db.services.langfuse_service import TenantLangfuseService | |||||
| from api.db.services.user_service import TenantService | |||||
| from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel | |||||
| class LLMFactoriesService(CommonService): | |||||
| model = LLMFactories | |||||
| class TenantLLMService(CommonService): | |||||
| model = TenantLLM | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_api_key(cls, tenant_id, model_name): | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) | |||||
| if not fid: | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) | |||||
| else: | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||||
| if (not objs) and fid: | |||||
| if fid == "LocalAI": | |||||
| mdlnm += "___LocalAI" | |||||
| elif fid == "HuggingFace": | |||||
| mdlnm += "___HuggingFace" | |||||
| elif fid == "OpenAI-API-Compatible": | |||||
| mdlnm += "___OpenAI-API" | |||||
| elif fid == "VLLM": | |||||
| mdlnm += "___VLLM" | |||||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||||
| if not objs: | |||||
| return | |||||
| return objs[0] | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_my_llms(cls, tenant_id): | |||||
| fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] | |||||
| objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() | |||||
| return list(objs) | |||||
| @staticmethod | |||||
| def split_model_name_and_factory(model_name): | |||||
| arr = model_name.split("@") | |||||
| if len(arr) < 2: | |||||
| return model_name, None | |||||
| if len(arr) > 2: | |||||
| return "@".join(arr[0:-1]), arr[-1] | |||||
| # model name must be xxx@yyy | |||||
| try: | |||||
| model_factories = settings.FACTORY_LLM_INFOS | |||||
| model_providers = set([f["name"] for f in model_factories]) | |||||
| if arr[-1] not in model_providers: | |||||
| return model_name, None | |||||
| return arr[0], arr[-1] | |||||
| except Exception as e: | |||||
| logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") | |||||
| return model_name, None | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_model_config(cls, tenant_id, llm_type, llm_name=None): | |||||
| from api.db.services.llm_service import LLMService | |||||
| e, tenant = TenantService.get_by_id(tenant_id) | |||||
| if not e: | |||||
| raise LookupError("Tenant not found") | |||||
| if llm_type == LLMType.EMBEDDING.value: | |||||
| mdlnm = tenant.embd_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.SPEECH2TEXT.value: | |||||
| mdlnm = tenant.asr_id | |||||
| elif llm_type == LLMType.IMAGE2TEXT.value: | |||||
| mdlnm = tenant.img2txt_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.CHAT.value: | |||||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.RERANK: | |||||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||||
| elif llm_type == LLMType.TTS: | |||||
| mdlnm = tenant.tts_id if not llm_name else llm_name | |||||
| else: | |||||
| assert False, "LLM type error" | |||||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) | |||||
| if not model_config: # for some cases seems fid mismatch | |||||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||||
| if model_config: | |||||
| model_config = model_config.to_dict() | |||||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||||
| if not llm and fid: # for some cases seems fid mismatch | |||||
| llm = LLMService.query(llm_name=mdlnm) | |||||
| if llm: | |||||
| model_config["is_tools"] = llm[0].is_tools | |||||
| if not model_config: | |||||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: | |||||
| model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} | |||||
| if not model_config: | |||||
| if mdlnm == "flag-embedding": | |||||
| model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} | |||||
| else: | |||||
| if not mdlnm: | |||||
| raise LookupError(f"Type of {llm_type} model is not set.") | |||||
| raise LookupError("Model({}) not authorized".format(mdlnm)) | |||||
| return model_config | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||||
| kwargs.update({"provider": model_config["llm_factory"]}) | |||||
| if llm_type == LLMType.EMBEDDING.value: | |||||
| if model_config["llm_factory"] not in EmbeddingModel: | |||||
| return | |||||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.RERANK: | |||||
| if model_config["llm_factory"] not in RerankModel: | |||||
| return | |||||
| return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||||
| if model_config["llm_factory"] not in CvModel: | |||||
| return | |||||
| return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) | |||||
| if llm_type == LLMType.CHAT.value: | |||||
| if model_config["llm_factory"] not in ChatModel: | |||||
| return | |||||
| return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) | |||||
| if llm_type == LLMType.SPEECH2TEXT: | |||||
| if model_config["llm_factory"] not in Seq2txtModel: | |||||
| return | |||||
| return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) | |||||
| if llm_type == LLMType.TTS: | |||||
| if model_config["llm_factory"] not in TTSModel: | |||||
| return | |||||
| return TTSModel[model_config["llm_factory"]]( | |||||
| model_config["api_key"], | |||||
| model_config["llm_name"], | |||||
| base_url=model_config["api_base"], | |||||
| ) | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): | |||||
| e, tenant = TenantService.get_by_id(tenant_id) | |||||
| if not e: | |||||
| logging.error(f"Tenant not found: {tenant_id}") | |||||
| return 0 | |||||
| llm_map = { | |||||
| LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, | |||||
| LLMType.SPEECH2TEXT.value: tenant.asr_id, | |||||
| LLMType.IMAGE2TEXT.value: tenant.img2txt_id, | |||||
| LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, | |||||
| LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, | |||||
| LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, | |||||
| } | |||||
| mdlnm = llm_map.get(llm_type) | |||||
| if mdlnm is None: | |||||
| logging.error(f"LLM type error: {llm_type}") | |||||
| return 0 | |||||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) | |||||
| try: | |||||
| num = ( | |||||
| cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) | |||||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) | |||||
| .execute() | |||||
| ) | |||||
| except Exception: | |||||
| logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) | |||||
| return 0 | |||||
| return num | |||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_openai_models(cls): | |||||
| objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() | |||||
| return list(objs) | |||||
| @staticmethod | |||||
| def llm_id2llm_type(llm_id: str) -> str | None: | |||||
| from api.db.services.llm_service import LLMService | |||||
| llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) | |||||
| llm_factories = settings.FACTORY_LLM_INFOS | |||||
| for llm_factory in llm_factories: | |||||
| for llm in llm_factory["llm"]: | |||||
| if llm_id == llm["llm_name"]: | |||||
| return llm["model_type"].split(",")[-1] | |||||
| for llm in LLMService.query(llm_name=llm_id): | |||||
| return llm.model_type | |||||
| llm = TenantLLMService.get_or_none(llm_name=llm_id) | |||||
| if llm: | |||||
| return llm.model_type | |||||
| for llm in TenantLLMService.query(llm_name=llm_id): | |||||
| return llm.model_type | |||||
| class LLM4Tenant: | |||||
| def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||||
| self.tenant_id = tenant_id | |||||
| self.llm_type = llm_type | |||||
| self.llm_name = llm_name | |||||
| self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) | |||||
| assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) | |||||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||||
| self.max_length = model_config.get("max_tokens", 8192) | |||||
| self.is_tools = model_config.get("is_tools", False) | |||||
| self.verbose_tool_use = kwargs.get("verbose_tool_use") | |||||
| langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) | |||||
| self.langfuse = None | |||||
| if langfuse_keys: | |||||
| langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) | |||||
| if langfuse.auth_check(): | |||||
| self.langfuse = langfuse | |||||
| trace_id = self.langfuse.create_trace_id() | |||||
| self.trace_context = {"trace_id": trace_id} |
| from api import settings | from api import settings | ||||
| from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC | from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC | ||||
| from api.db.db_models import APIToken | from api.db.db_models import APIToken | ||||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||||
| from api.db.services.llm_service import LLMService | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| from api.utils import CustomJSONEncoder, get_uuid, json_dumps | from api.utils import CustomJSONEncoder, get_uuid, json_dumps | ||||
| from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions | from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions | ||||
| def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): | def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): | ||||
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.db.services.llm_service import TenantLLMService | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| if not chat_mdl: | if not chat_mdl: | ||||
| if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | ||||
| def cross_languages(tenant_id, llm_id, query, languages=[]): | def cross_languages(tenant_id, llm_id, query, languages=[]): | ||||
| from api.db import LLMType | from api.db import LLMType | ||||
| from api.db.services.llm_service import LLMBundle | from api.db.services.llm_service import LLMBundle | ||||
| from api.db.services.llm_service import TenantLLMService | |||||
| from api.db.services.tenant_llm_service import TenantLLMService | |||||
| if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | ||||
| chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) | chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) |