### What problem does this PR solve? ### Type of change - [x] Refactoringtags/v0.20.2
| @@ -24,7 +24,8 @@ from typing import Any | |||
| import json_repair | |||
| from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta | |||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.db.services.mcp_server_service import MCPServerService | |||
| from api.utils.api_utils import timeout | |||
| from rag.prompts import message_fit_in | |||
| @@ -24,7 +24,8 @@ from copy import deepcopy | |||
| from functools import partial | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| from api.utils.api_utils import timeout | |||
| from rag.prompts import message_fit_in, citation_prompt | |||
| @@ -28,8 +28,8 @@ from api.db.db_models import APIToken | |||
| from api.db.services.conversation_service import ConversationService, structure_answer | |||
| from api.db.services.dialog_service import DialogService, ask, chat | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMBundle, TenantService | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.user_service import UserTenantService, TenantService | |||
| from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request | |||
| from graphrag.general.mind_map_extractor import MindMapExtractor | |||
| from rag.app.tag import label_question | |||
| @@ -18,7 +18,7 @@ from flask import request | |||
| from flask_login import login_required, current_user | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db import StatusEnum | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api import settings | |||
| @@ -17,7 +17,8 @@ import logging | |||
| import json | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService | |||
| from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService | |||
| from api.db.services.llm_service import LLMService | |||
| from api import settings | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.db import StatusEnum, LLMType | |||
| @@ -21,7 +21,7 @@ from api import settings | |||
| from api.db import StatusEnum | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.db.services.user_service import TenantService | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required | |||
| @@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.db.services.task_service import TaskService, queue_tasks | |||
| from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required | |||
| from rag.app.qa import beAdoc, rmPrefix | |||
| @@ -16,20 +16,17 @@ | |||
| import json | |||
| import re | |||
| import time | |||
| import tiktoken | |||
| from flask import Response, jsonify, request | |||
| from agent.canvas import Canvas | |||
| from api.db import LLMType, StatusEnum | |||
| from api.db.db_models import API4Conversation, APIToken | |||
| from api.db.db_models import APIToken | |||
| from api.db.services.api_service import API4ConversationService | |||
| from api.db.services.canvas_service import UserCanvasService, completionOpenAI | |||
| from api.db.services.canvas_service import completion as agent_completion | |||
| from api.db.services.conversation_service import ConversationService, iframe_completion | |||
| from api.db.services.conversation_service import completion as rag_completion | |||
| from api.db.services.dialog_service import DialogService, ask, chat | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.utils import get_uuid | |||
| @@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client | |||
| from api.db import FileType, UserTenantRole | |||
| from api.db.db_models import TenantLLM | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm | |||
| from api.db.services.user_service import TenantService, UserService, UserTenantService | |||
| from api.utils import ( | |||
| current_timestamp, | |||
| @@ -619,57 +619,8 @@ def user_register(user_id, user): | |||
| "size": 0, | |||
| "location": "", | |||
| } | |||
| tenant_llm = [] | |||
| seen = set() | |||
| factory_configs = [] | |||
| for factory_config in [ | |||
| settings.CHAT_CFG, | |||
| settings.EMBEDDING_CFG, | |||
| settings.ASR_CFG, | |||
| settings.IMAGE2TEXT_CFG, | |||
| settings.RERANK_CFG, | |||
| ]: | |||
| factory_name = factory_config["factory"] | |||
| if factory_name not in seen: | |||
| seen.add(factory_name) | |||
| factory_configs.append(factory_config) | |||
| for factory_config in factory_configs: | |||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": factory_config["factory"], | |||
| "llm_name": llm.llm_name, | |||
| "model_type": llm.model_type, | |||
| "api_key": factory_config["api_key"], | |||
| "api_base": factory_config["base_url"], | |||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||
| } | |||
| ) | |||
| if settings.LIGHTEN != 1: | |||
| for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": fid, | |||
| "llm_name": mdlnm, | |||
| "model_type": "embedding", | |||
| "api_key": "", | |||
| "api_base": "", | |||
| "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, | |||
| } | |||
| ) | |||
| unique = {} | |||
| for item in tenant_llm: | |||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||
| if key not in unique: | |||
| unique[key] = item | |||
| tenant_llm = list(unique.values()) | |||
| tenant_llm = get_init_tenant_llm(user_id) | |||
| if not UserService.save(**user): | |||
| return | |||
| @@ -27,7 +27,8 @@ from api.db.services import UserService | |||
| from api.db.services.canvas_service import CanvasTemplateService | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle | |||
| from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService | |||
| from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api import settings | |||
| from api.utils.file_utils import get_project_base_directory | |||
| @@ -64,43 +65,7 @@ def init_superuser(): | |||
| "role": UserTenantRole.OWNER | |||
| } | |||
| user_id = user_info | |||
| tenant_llm = [] | |||
| seen = set() | |||
| factory_configs = [] | |||
| for factory_config in [ | |||
| settings.CHAT_CFG["factory"], | |||
| settings.EMBEDDING_CFG["factory"], | |||
| settings.ASR_CFG["factory"], | |||
| settings.IMAGE2TEXT_CFG["factory"], | |||
| settings.RERANK_CFG["factory"], | |||
| ]: | |||
| factory_name = factory_config["factory"] | |||
| if factory_name not in seen: | |||
| seen.add(factory_name) | |||
| factory_configs.append(factory_config) | |||
| for factory_config in factory_configs: | |||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": factory_config["factory"], | |||
| "llm_name": llm.llm_name, | |||
| "model_type": llm.model_type, | |||
| "api_key": factory_config["api_key"], | |||
| "api_base": factory_config["base_url"], | |||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||
| } | |||
| ) | |||
| unique = {} | |||
| for item in tenant_llm: | |||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||
| if key not in unique: | |||
| unique[key] = item | |||
| tenant_llm = list(unique.values()) | |||
| tenant_llm = get_init_tenant_llm(user_info["id"]) | |||
| if not UserService.save(**user_info): | |||
| logging.error("can't init admin.") | |||
| @@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.langfuse_service import TenantLangfuseService | |||
| from api.db.services.llm_service import LLMBundle, TenantLLMService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.utils import current_timestamp, datetime_format | |||
| from rag.app.resume import forbidden_select_fields4resume | |||
| from rag.app.tag import label_question | |||
| @@ -18,246 +18,73 @@ import logging | |||
| import re | |||
| from functools import partial | |||
| from typing import Generator | |||
| from langfuse import Langfuse | |||
| from api import settings | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB, LLM, LLMFactories, TenantLLM | |||
| from api.db.db_models import LLM | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.langfuse_service import TenantLangfuseService | |||
| from api.db.services.user_service import TenantService | |||
| from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel | |||
| class LLMFactoriesService(CommonService): | |||
| model = LLMFactories | |||
| from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService | |||
| class LLMService(CommonService): | |||
| model = LLM | |||
| class TenantLLMService(CommonService): | |||
| model = TenantLLM | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_api_key(cls, tenant_id, model_name): | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) | |||
| if not fid: | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) | |||
| else: | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||
| if (not objs) and fid: | |||
| if fid == "LocalAI": | |||
| mdlnm += "___LocalAI" | |||
| elif fid == "HuggingFace": | |||
| mdlnm += "___HuggingFace" | |||
| elif fid == "OpenAI-API-Compatible": | |||
| mdlnm += "___OpenAI-API" | |||
| elif fid == "VLLM": | |||
| mdlnm += "___VLLM" | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||
| if not objs: | |||
| return | |||
| return objs[0] | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_my_llms(cls, tenant_id): | |||
| fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] | |||
| objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() | |||
| return list(objs) | |||
| @staticmethod | |||
| def split_model_name_and_factory(model_name): | |||
| arr = model_name.split("@") | |||
| if len(arr) < 2: | |||
| return model_name, None | |||
| if len(arr) > 2: | |||
| return "@".join(arr[0:-1]), arr[-1] | |||
| # model name must be xxx@yyy | |||
| try: | |||
| model_factories = settings.FACTORY_LLM_INFOS | |||
| model_providers = set([f["name"] for f in model_factories]) | |||
| if arr[-1] not in model_providers: | |||
| return model_name, None | |||
| return arr[0], arr[-1] | |||
| except Exception as e: | |||
| logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") | |||
| return model_name, None | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_model_config(cls, tenant_id, llm_type, llm_name=None): | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| raise LookupError("Tenant not found") | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| mdlnm = tenant.embd_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.SPEECH2TEXT.value: | |||
| mdlnm = tenant.asr_id | |||
| elif llm_type == LLMType.IMAGE2TEXT.value: | |||
| mdlnm = tenant.img2txt_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.CHAT.value: | |||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.RERANK: | |||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.TTS: | |||
| mdlnm = tenant.tts_id if not llm_name else llm_name | |||
| else: | |||
| assert False, "LLM type error" | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) | |||
| if not model_config: # for some cases seems fid mismatch | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| if model_config: | |||
| model_config = model_config.to_dict() | |||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||
| if not llm and fid: # for some cases seems fid mismatch | |||
| llm = LLMService.query(llm_name=mdlnm) | |||
| if llm: | |||
| model_config["is_tools"] = llm[0].is_tools | |||
| if not model_config: | |||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: | |||
| model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} | |||
| if not model_config: | |||
| if mdlnm == "flag-embedding": | |||
| model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} | |||
| else: | |||
| if not mdlnm: | |||
| raise LookupError(f"Type of {llm_type} model is not set.") | |||
| raise LookupError("Model({}) not authorized".format(mdlnm)) | |||
| return model_config | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||
| kwargs.update({"provider": model_config["llm_factory"]}) | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| if model_config["llm_factory"] not in EmbeddingModel: | |||
| return | |||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.RERANK: | |||
| if model_config["llm_factory"] not in RerankModel: | |||
| return | |||
| return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||
| if model_config["llm_factory"] not in CvModel: | |||
| return | |||
| return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) | |||
| if llm_type == LLMType.CHAT.value: | |||
| if model_config["llm_factory"] not in ChatModel: | |||
| return | |||
| return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) | |||
| if llm_type == LLMType.SPEECH2TEXT: | |||
| if model_config["llm_factory"] not in Seq2txtModel: | |||
| return | |||
| return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.TTS: | |||
| if model_config["llm_factory"] not in TTSModel: | |||
| return | |||
| return TTSModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], | |||
| model_config["llm_name"], | |||
| base_url=model_config["api_base"], | |||
| def get_init_tenant_llm(user_id): | |||
| from api import settings | |||
| tenant_llm = [] | |||
| seen = set() | |||
| factory_configs = [] | |||
| for factory_config in [ | |||
| settings.CHAT_CFG, | |||
| settings.EMBEDDING_CFG, | |||
| settings.ASR_CFG, | |||
| settings.IMAGE2TEXT_CFG, | |||
| settings.RERANK_CFG, | |||
| ]: | |||
| factory_name = factory_config["factory"] | |||
| if factory_name not in seen: | |||
| seen.add(factory_name) | |||
| factory_configs.append(factory_config) | |||
| for factory_config in factory_configs: | |||
| for llm in LLMService.query(fid=factory_config["factory"]): | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": factory_config["factory"], | |||
| "llm_name": llm.llm_name, | |||
| "model_type": llm.model_type, | |||
| "api_key": factory_config["api_key"], | |||
| "api_base": factory_config["base_url"], | |||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||
| } | |||
| ) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| logging.error(f"Tenant not found: {tenant_id}") | |||
| return 0 | |||
| llm_map = { | |||
| LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, | |||
| LLMType.SPEECH2TEXT.value: tenant.asr_id, | |||
| LLMType.IMAGE2TEXT.value: tenant.img2txt_id, | |||
| LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, | |||
| LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, | |||
| LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, | |||
| } | |||
| mdlnm = llm_map.get(llm_type) | |||
| if mdlnm is None: | |||
| logging.error(f"LLM type error: {llm_type}") | |||
| return 0 | |||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) | |||
| try: | |||
| num = ( | |||
| cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) | |||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) | |||
| .execute() | |||
| if settings.LIGHTEN != 1: | |||
| for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": fid, | |||
| "llm_name": mdlnm, | |||
| "model_type": "embedding", | |||
| "api_key": "", | |||
| "api_base": "", | |||
| "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, | |||
| } | |||
| ) | |||
| except Exception: | |||
| logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) | |||
| return 0 | |||
| return num | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_openai_models(cls): | |||
| objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() | |||
| return list(objs) | |||
| @staticmethod | |||
| def llm_id2llm_type(llm_id: str) -> str | None: | |||
| llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) | |||
| llm_factories = settings.FACTORY_LLM_INFOS | |||
| for llm_factory in llm_factories: | |||
| for llm in llm_factory["llm"]: | |||
| if llm_id == llm["llm_name"]: | |||
| return llm["model_type"].split(",")[-1] | |||
| for llm in LLMService.query(llm_name=llm_id): | |||
| return llm.model_type | |||
| llm = TenantLLMService.get_or_none(llm_name=llm_id) | |||
| if llm: | |||
| return llm.model_type | |||
| for llm in TenantLLMService.query(llm_name=llm_id): | |||
| return llm.model_type | |||
| unique = {} | |||
| for item in tenant_llm: | |||
| key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) | |||
| if key not in unique: | |||
| unique[key] = item | |||
| return list(unique.values()) | |||
| class LLMBundle: | |||
| class LLMBundle(LLM4Tenant): | |||
| def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||
| self.tenant_id = tenant_id | |||
| self.llm_type = llm_type | |||
| self.llm_name = llm_name | |||
| self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) | |||
| assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) | |||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||
| self.max_length = model_config.get("max_tokens", 8192) | |||
| self.is_tools = model_config.get("is_tools", False) | |||
| self.verbose_tool_use = kwargs.get("verbose_tool_use") | |||
| langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) | |||
| self.langfuse = None | |||
| if langfuse_keys: | |||
| langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) | |||
| if langfuse.auth_check(): | |||
| self.langfuse = langfuse | |||
| trace_id = self.langfuse.create_trace_id() | |||
| self.trace_context = {"trace_id": trace_id} | |||
| super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs) | |||
| def bind_tools(self, toolcall_session, tools): | |||
| if not self.is_tools: | |||
| @@ -0,0 +1,252 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| from langfuse import Langfuse | |||
| from api import settings | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB, LLMFactories, TenantLLM | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.langfuse_service import TenantLangfuseService | |||
| from api.db.services.user_service import TenantService | |||
| from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel | |||
| class LLMFactoriesService(CommonService): | |||
| model = LLMFactories | |||
| class TenantLLMService(CommonService): | |||
| model = TenantLLM | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_api_key(cls, tenant_id, model_name): | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) | |||
| if not fid: | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) | |||
| else: | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||
| if (not objs) and fid: | |||
| if fid == "LocalAI": | |||
| mdlnm += "___LocalAI" | |||
| elif fid == "HuggingFace": | |||
| mdlnm += "___HuggingFace" | |||
| elif fid == "OpenAI-API-Compatible": | |||
| mdlnm += "___OpenAI-API" | |||
| elif fid == "VLLM": | |||
| mdlnm += "___VLLM" | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) | |||
| if not objs: | |||
| return | |||
| return objs[0] | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_my_llms(cls, tenant_id): | |||
| fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] | |||
| objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() | |||
| return list(objs) | |||
| @staticmethod | |||
| def split_model_name_and_factory(model_name): | |||
| arr = model_name.split("@") | |||
| if len(arr) < 2: | |||
| return model_name, None | |||
| if len(arr) > 2: | |||
| return "@".join(arr[0:-1]), arr[-1] | |||
| # model name must be xxx@yyy | |||
| try: | |||
| model_factories = settings.FACTORY_LLM_INFOS | |||
| model_providers = set([f["name"] for f in model_factories]) | |||
| if arr[-1] not in model_providers: | |||
| return model_name, None | |||
| return arr[0], arr[-1] | |||
| except Exception as e: | |||
| logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") | |||
| return model_name, None | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_model_config(cls, tenant_id, llm_type, llm_name=None): | |||
| from api.db.services.llm_service import LLMService | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| raise LookupError("Tenant not found") | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| mdlnm = tenant.embd_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.SPEECH2TEXT.value: | |||
| mdlnm = tenant.asr_id | |||
| elif llm_type == LLMType.IMAGE2TEXT.value: | |||
| mdlnm = tenant.img2txt_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.CHAT.value: | |||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.RERANK: | |||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.TTS: | |||
| mdlnm = tenant.tts_id if not llm_name else llm_name | |||
| else: | |||
| assert False, "LLM type error" | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) | |||
| if not model_config: # for some cases seems fid mismatch | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| if model_config: | |||
| model_config = model_config.to_dict() | |||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||
| if not llm and fid: # for some cases seems fid mismatch | |||
| llm = LLMService.query(llm_name=mdlnm) | |||
| if llm: | |||
| model_config["is_tools"] = llm[0].is_tools | |||
| if not model_config: | |||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||
| llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) | |||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: | |||
| model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} | |||
| if not model_config: | |||
| if mdlnm == "flag-embedding": | |||
| model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} | |||
| else: | |||
| if not mdlnm: | |||
| raise LookupError(f"Type of {llm_type} model is not set.") | |||
| raise LookupError("Model({}) not authorized".format(mdlnm)) | |||
| return model_config | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||
| kwargs.update({"provider": model_config["llm_factory"]}) | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| if model_config["llm_factory"] not in EmbeddingModel: | |||
| return | |||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.RERANK: | |||
| if model_config["llm_factory"] not in RerankModel: | |||
| return | |||
| return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||
| if model_config["llm_factory"] not in CvModel: | |||
| return | |||
| return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) | |||
| if llm_type == LLMType.CHAT.value: | |||
| if model_config["llm_factory"] not in ChatModel: | |||
| return | |||
| return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) | |||
| if llm_type == LLMType.SPEECH2TEXT: | |||
| if model_config["llm_factory"] not in Seq2txtModel: | |||
| return | |||
| return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.TTS: | |||
| if model_config["llm_factory"] not in TTSModel: | |||
| return | |||
| return TTSModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], | |||
| model_config["llm_name"], | |||
| base_url=model_config["api_base"], | |||
| ) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| logging.error(f"Tenant not found: {tenant_id}") | |||
| return 0 | |||
| llm_map = { | |||
| LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, | |||
| LLMType.SPEECH2TEXT.value: tenant.asr_id, | |||
| LLMType.IMAGE2TEXT.value: tenant.img2txt_id, | |||
| LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, | |||
| LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, | |||
| LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, | |||
| } | |||
| mdlnm = llm_map.get(llm_type) | |||
| if mdlnm is None: | |||
| logging.error(f"LLM type error: {llm_type}") | |||
| return 0 | |||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) | |||
| try: | |||
| num = ( | |||
| cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) | |||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) | |||
| .execute() | |||
| ) | |||
| except Exception: | |||
| logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) | |||
| return 0 | |||
| return num | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_openai_models(cls): | |||
| objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() | |||
| return list(objs) | |||
| @staticmethod | |||
| def llm_id2llm_type(llm_id: str) -> str | None: | |||
| from api.db.services.llm_service import LLMService | |||
| llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) | |||
| llm_factories = settings.FACTORY_LLM_INFOS | |||
| for llm_factory in llm_factories: | |||
| for llm in llm_factory["llm"]: | |||
| if llm_id == llm["llm_name"]: | |||
| return llm["model_type"].split(",")[-1] | |||
| for llm in LLMService.query(llm_name=llm_id): | |||
| return llm.model_type | |||
| llm = TenantLLMService.get_or_none(llm_name=llm_id) | |||
| if llm: | |||
| return llm.model_type | |||
| for llm in TenantLLMService.query(llm_name=llm_id): | |||
| return llm.model_type | |||
| class LLM4Tenant: | |||
| def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): | |||
| self.tenant_id = tenant_id | |||
| self.llm_type = llm_type | |||
| self.llm_name = llm_name | |||
| self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) | |||
| assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) | |||
| model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) | |||
| self.max_length = model_config.get("max_tokens", 8192) | |||
| self.is_tools = model_config.get("is_tools", False) | |||
| self.verbose_tool_use = kwargs.get("verbose_tool_use") | |||
| langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) | |||
| self.langfuse = None | |||
| if langfuse_keys: | |||
| langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) | |||
| if langfuse.auth_check(): | |||
| self.langfuse = langfuse | |||
| trace_id = self.langfuse.create_trace_id() | |||
| self.trace_context = {"trace_id": trace_id} | |||
| @@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES | |||
| from api import settings | |||
| from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC | |||
| from api.db.db_models import APIToken | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.db.services.llm_service import LLMService | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| from api.utils import CustomJSONEncoder, get_uuid, json_dumps | |||
| from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions | |||
| @@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3): | |||
| def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| if not chat_mdl: | |||
| if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | |||
| @@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_ | |||
| def cross_languages(tenant_id, llm_id, query, languages=[]): | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.tenant_llm_service import TenantLLMService | |||
| if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text": | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) | |||