### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
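In short: instead of hand-maintaining one factory dictionary per model type in `rag/llm/__init__.py`, every model class now declares a `_FACTORY_NAME` class attribute, and the package discovers and registers the classes at import time. A minimal, self-contained sketch of the pattern (the stub classes here are illustrative, not the real RAGFlow ones; the real loop walks `rag.llm.chat_model`, `rag.llm.cv_model`, etc. with `importlib`):

```python
import inspect
import sys
from abc import ABC


class Base(ABC):
    """Common base class; the subclasses below are illustrative stand-ins."""


class OpenAIChat(Base):
    _FACTORY_NAME = "OpenAI"


class VLLMChat(Base):
    # A list registers one class under several factory names.
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]


ChatModel = {}
module = sys.modules[__name__]  # the real code resolves each submodule via importlib.import_module

for _, obj in inspect.getmembers(module):
    if inspect.isclass(obj) and issubclass(obj, Base) and obj is not Base and hasattr(obj, "_FACTORY_NAME"):
        names = obj._FACTORY_NAME if isinstance(obj._FACTORY_NAME, list) else [obj._FACTORY_NAME]
        for factory_name in names:
            ChatModel[factory_name] = obj

print(ChatModel)  # {'OpenAI': OpenAIChat, 'VLLM': VLLMChat, 'OpenAI-API-Compatible': VLLMChat}
```

Adding a new provider then only requires defining the class with its `_FACTORY_NAME`; `__init__.py` no longer has to be touched.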
@@ -15,289 +15,53 @@
 #
 # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
 #
-from .embedding_model import (
-    OllamaEmbed,
-    LocalAIEmbed,
-    OpenAIEmbed,
-    AzureEmbed,
-    XinferenceEmbed,
-    QWenEmbed,
-    ZhipuEmbed,
-    FastEmbed,
-    YoudaoEmbed,
-    BaiChuanEmbed,
-    JinaEmbed,
-    DefaultEmbedding,
-    MistralEmbed,
-    BedrockEmbed,
-    GeminiEmbed,
-    NvidiaEmbed,
-    LmStudioEmbed,
-    OpenAI_APIEmbed,
-    CoHereEmbed,
-    TogetherAIEmbed,
-    PerfXCloudEmbed,
-    UpstageEmbed,
-    SILICONFLOWEmbed,
-    ReplicateEmbed,
-    BaiduYiyanEmbed,
-    VoyageEmbed,
-    HuggingFaceEmbed,
-    VolcEngineEmbed,
-    GPUStackEmbed,
-    NovitaEmbed,
-    GiteeEmbed
-)
-from .chat_model import (
-    GptTurbo,
-    AzureChat,
-    ZhipuChat,
-    QWenChat,
-    OllamaChat,
-    LocalAIChat,
-    XinferenceChat,
-    MoonshotChat,
-    DeepSeekChat,
-    VolcEngineChat,
-    BaiChuanChat,
-    MiniMaxChat,
-    MistralChat,
-    GeminiChat,
-    BedrockChat,
-    GroqChat,
-    OpenRouterChat,
-    StepFunChat,
-    NvidiaChat,
-    LmStudioChat,
-    OpenAI_APIChat,
-    CoHereChat,
-    LeptonAIChat,
-    TogetherAIChat,
-    PerfXCloudChat,
-    UpstageChat,
-    NovitaAIChat,
-    SILICONFLOWChat,
-    PPIOChat,
-    YiChat,
-    ReplicateChat,
-    HunyuanChat,
-    SparkChat,
-    BaiduYiyanChat,
-    AnthropicChat,
-    GoogleChat,
-    HuggingFaceChat,
-    GPUStackChat,
-    ModelScopeChat,
-    GiteeChat
-)
-from .cv_model import (
-    GptV4,
-    AzureGptV4,
-    OllamaCV,
-    XinferenceCV,
-    QWenCV,
-    Zhipu4V,
-    LocalCV,
-    GeminiCV,
-    OpenRouterCV,
-    LocalAICV,
-    NvidiaCV,
-    LmStudioCV,
-    StepFunCV,
-    OpenAI_APICV,
-    TogetherAICV,
-    YiCV,
-    HunyuanCV,
-    AnthropicCV,
-    SILICONFLOWCV,
-    GPUStackCV,
-    GoogleCV,
-)
-from .rerank_model import (
-    LocalAIRerank,
-    DefaultRerank,
-    JinaRerank,
-    YoudaoRerank,
-    XInferenceRerank,
-    NvidiaRerank,
-    LmStudioRerank,
-    OpenAI_APIRerank,
-    CoHereRerank,
-    TogetherAIRerank,
-    SILICONFLOWRerank,
-    BaiduYiyanRerank,
-    VoyageRerank,
-    QWenRerank,
-    GPUStackRerank,
-    HuggingfaceRerank,
-    NovitaRerank,
-    GiteeRerank
-)
-from .sequence2txt_model import (
-    GPTSeq2txt,
-    QWenSeq2txt,
-    AzureSeq2txt,
-    XinferenceSeq2txt,
-    TencentCloudSeq2txt,
-    GPUStackSeq2txt,
-    GiteeSeq2txt
-)
-from .tts_model import (
-    FishAudioTTS,
-    QwenTTS,
-    OpenAITTS,
-    SparkTTS,
-    XinferenceTTS,
-    GPUStackTTS,
-    SILICONFLOWTTS,
-)
-
-EmbeddingModel = {
-    "Ollama": OllamaEmbed,
-    "LocalAI": LocalAIEmbed,
-    "OpenAI": OpenAIEmbed,
-    "Azure-OpenAI": AzureEmbed,
-    "Xinference": XinferenceEmbed,
-    "Tongyi-Qianwen": QWenEmbed,
-    "ZHIPU-AI": ZhipuEmbed,
-    "FastEmbed": FastEmbed,
-    "Youdao": YoudaoEmbed,
-    "BaiChuan": BaiChuanEmbed,
-    "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding,
-    "Mistral": MistralEmbed,
-    "Bedrock": BedrockEmbed,
-    "Gemini": GeminiEmbed,
-    "NVIDIA": NvidiaEmbed,
-    "LM-Studio": LmStudioEmbed,
-    "OpenAI-API-Compatible": OpenAI_APIEmbed,
-    "VLLM": OpenAI_APIEmbed,
-    "Cohere": CoHereEmbed,
-    "TogetherAI": TogetherAIEmbed,
-    "PerfXCloud": PerfXCloudEmbed,
-    "Upstage": UpstageEmbed,
-    "SILICONFLOW": SILICONFLOWEmbed,
-    "Replicate": ReplicateEmbed,
-    "BaiduYiyan": BaiduYiyanEmbed,
-    "Voyage AI": VoyageEmbed,
-    "HuggingFace": HuggingFaceEmbed,
-    "VolcEngine": VolcEngineEmbed,
-    "GPUStack": GPUStackEmbed,
-    "NovitaAI": NovitaEmbed,
-    "GiteeAI": GiteeEmbed
-}
-
-CvModel = {
-    "OpenAI": GptV4,
-    "Azure-OpenAI": AzureGptV4,
-    "Ollama": OllamaCV,
-    "Xinference": XinferenceCV,
-    "Tongyi-Qianwen": QWenCV,
-    "ZHIPU-AI": Zhipu4V,
-    "Moonshot": LocalCV,
-    "Gemini": GeminiCV,
-    "OpenRouter": OpenRouterCV,
-    "LocalAI": LocalAICV,
-    "NVIDIA": NvidiaCV,
-    "LM-Studio": LmStudioCV,
-    "StepFun": StepFunCV,
-    "OpenAI-API-Compatible": OpenAI_APICV,
-    "VLLM": OpenAI_APICV,
-    "TogetherAI": TogetherAICV,
-    "01.AI": YiCV,
-    "Tencent Hunyuan": HunyuanCV,
-    "Anthropic": AnthropicCV,
-    "SILICONFLOW": SILICONFLOWCV,
-    "GPUStack": GPUStackCV,
-    "Google Cloud": GoogleCV
-}
-
-ChatModel = {
-    "OpenAI": GptTurbo,
-    "Azure-OpenAI": AzureChat,
-    "ZHIPU-AI": ZhipuChat,
-    "Tongyi-Qianwen": QWenChat,
-    "Ollama": OllamaChat,
-    "LocalAI": LocalAIChat,
-    "Xinference": XinferenceChat,
-    "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat,
-    "VolcEngine": VolcEngineChat,
-    "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat,
-    "Mistral": MistralChat,
-    "Gemini": GeminiChat,
-    "Bedrock": BedrockChat,
-    "Groq": GroqChat,
-    "OpenRouter": OpenRouterChat,
-    "StepFun": StepFunChat,
-    "NVIDIA": NvidiaChat,
-    "LM-Studio": LmStudioChat,
-    "OpenAI-API-Compatible": OpenAI_APIChat,
-    "VLLM": OpenAI_APIChat,
-    "Cohere": CoHereChat,
-    "LeptonAI": LeptonAIChat,
-    "TogetherAI": TogetherAIChat,
-    "PerfXCloud": PerfXCloudChat,
-    "Upstage": UpstageChat,
-    "NovitaAI": NovitaAIChat,
-    "SILICONFLOW": SILICONFLOWChat,
-    "PPIO": PPIOChat,
-    "01.AI": YiChat,
-    "Replicate": ReplicateChat,
-    "Tencent Hunyuan": HunyuanChat,
-    "XunFei Spark": SparkChat,
-    "BaiduYiyan": BaiduYiyanChat,
-    "Anthropic": AnthropicChat,
-    "Google Cloud": GoogleChat,
-    "HuggingFace": HuggingFaceChat,
-    "GPUStack": GPUStackChat,
-    "ModelScope": ModelScopeChat,
-    "GiteeAI": GiteeChat
-}
-
-RerankModel = {
-    "LocalAI": LocalAIRerank,
-    "BAAI": DefaultRerank,
-    "Jina": JinaRerank,
-    "Youdao": YoudaoRerank,
-    "Xinference": XInferenceRerank,
-    "NVIDIA": NvidiaRerank,
-    "LM-Studio": LmStudioRerank,
-    "OpenAI-API-Compatible": OpenAI_APIRerank,
-    "VLLM": CoHereRerank,
-    "Cohere": CoHereRerank,
-    "TogetherAI": TogetherAIRerank,
-    "SILICONFLOW": SILICONFLOWRerank,
-    "BaiduYiyan": BaiduYiyanRerank,
-    "Voyage AI": VoyageRerank,
-    "Tongyi-Qianwen": QWenRerank,
-    "GPUStack": GPUStackRerank,
-    "HuggingFace": HuggingfaceRerank,
-    "NovitaAI": NovitaRerank,
-    "GiteeAI": GiteeRerank
-}
-
-Seq2txtModel = {
-    "OpenAI": GPTSeq2txt,
-    "Tongyi-Qianwen": QWenSeq2txt,
-    "Azure-OpenAI": AzureSeq2txt,
-    "Xinference": XinferenceSeq2txt,
-    "Tencent Cloud": TencentCloudSeq2txt,
-    "GPUStack": GPUStackSeq2txt,
-    "GiteeAI": GiteeSeq2txt
-}
-
-TTSModel = {
-    "Fish Audio": FishAudioTTS,
-    "Tongyi-Qianwen": QwenTTS,
-    "OpenAI": OpenAITTS,
-    "XunFei Spark": SparkTTS,
-    "Xinference": XinferenceTTS,
-    "GPUStack": GPUStackTTS,
-    "SILICONFLOW": SILICONFLOWTTS,
-}
+import importlib
+import inspect
+
+ChatModel = globals().get("ChatModel", {})
+CvModel = globals().get("CvModel", {})
+EmbeddingModel = globals().get("EmbeddingModel", {})
+RerankModel = globals().get("RerankModel", {})
+Seq2txtModel = globals().get("Seq2txtModel", {})
+TTSModel = globals().get("TTSModel", {})
+
+MODULE_MAPPING = {
+    "chat_model": ChatModel,
+    "cv_model": CvModel,
+    "embedding_model": EmbeddingModel,
+    "rerank_model": RerankModel,
+    "sequence2txt_model": Seq2txtModel,
+    "tts_model": TTSModel,
+}
+
+package_name = __name__
+
+for module_name, mapping_dict in MODULE_MAPPING.items():
+    full_module_name = f"{package_name}.{module_name}"
+    module = importlib.import_module(full_module_name)
+
+    base_class = None
+    for name, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and name == "Base":
+            base_class = obj
+            break
+    if base_class is None:
+        continue
+
+    for _, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
+            if isinstance(obj._FACTORY_NAME, list):
+                for factory_name in obj._FACTORY_NAME:
+                    mapping_dict[factory_name] = obj
+            else:
+                mapping_dict[obj._FACTORY_NAME] = obj
+
+__all__ = [
+    "ChatModel",
+    "CvModel",
+    "EmbeddingModel",
+    "RerankModel",
+    "Seq2txtModel",
+    "TTSModel",
+]
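Nothing changes for callers of these registries; only how the dictionaries are populated is different. A usage sketch (the API key value is a placeholder; the factory names and classes are the ones registered by the loop above):

```python
from rag.llm import ChatModel, EmbeddingModel

chat_cls = ChatModel["OpenAI"]  # resolved to GptTurbo by the registration loop
mdl = chat_cls(key="sk-...", model_name="gpt-3.5-turbo")

embed_cls = EmbeddingModel["BAAI"]  # resolved to DefaultEmbedding
```

The remaining hunks below add the `_FACTORY_NAME` attribute to each provider class (alongside unrelated formatting cleanups).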
@@ -142,11 +142,7 @@ class Base(ABC):
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

     def _verbose_tool_use(self, name, args, res):
-        return "<tool_call>" + json.dumps({
-            "name": name,
-            "args": args,
-            "result": res
-        }, ensure_ascii=False, indent=2) + "</tool_call>"
+        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

     def _append_history(self, hist, tool_call, tool_res):
         hist.append(
@@ -191,10 +187,10 @@ class Base(ABC):
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries+1):
+        for attempt in range(self.max_retries + 1):
             history = hist
             try:
-                for _ in range(self.max_rounds*2):
+                for _ in range(self.max_rounds * 2):
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
                     if any([not response.choices, not response.choices[0].message]):
@@ -269,7 +265,7 @@ class Base(ABC):
         for attempt in range(self.max_retries + 1):
             history = hist
             try:
-                for _ in range(self.max_rounds*2):
+                for _ in range(self.max_rounds * 2):
                     reasoning_start = False
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
@@ -430,6 +426,8 @@ class Base(ABC):
 class GptTurbo(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.openai.com/v1"
@@ -437,6 +435,8 @@ class GptTurbo(Base):
 class MoonshotChat(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.moonshot.cn/v1"
@@ -444,6 +444,8 @@ class MoonshotChat(Base):
 class XinferenceChat(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -452,6 +454,8 @@ class XinferenceChat(Base):
 class HuggingFaceChat(Base):
+    _FACTORY_NAME = "HuggingFace"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):
 class ModelScopeChat(Base):
+    _FACTORY_NAME = "ModelScope"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -468,6 +474,8 @@ class ModelScopeChat(Base):
 class DeepSeekChat(Base):
+    _FACTORY_NAME = "DeepSeek"
+
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.deepseek.com/v1"
@@ -475,6 +483,8 @@ class DeepSeekChat(Base):
 class AzureChat(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         api_key = json.loads(key).get("api_key", "")
         api_version = json.loads(key).get("api_version", "2024-02-01")
@@ -484,6 +494,8 @@ class AzureChat(Base):
 class BaiChuanChat(Base):
+    _FACTORY_NAME = "BaiChuan"
+
     def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.baichuan-ai.com/v1"
@@ -557,6 +569,8 @@ class BaiChuanChat(Base):
 class QWenChat(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
         if not base_url:
             base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -565,6 +579,8 @@ class QWenChat(Base):
 class ZhipuChat(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -630,6 +646,8 @@ class ZhipuChat(Base):
 class OllamaChat(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -694,6 +712,8 @@ class OllamaChat(Base):
 class LocalAIChat(Base):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -752,6 +772,8 @@ class LocalLLM(Base):
 class VolcEngineChat(Base):
+    _FACTORY_NAME = "VolcEngine"
+
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
@@ -765,6 +787,8 @@ class VolcEngineChat(Base):
 class MiniMaxChat(Base):
+    _FACTORY_NAME = "MiniMax"
+
     def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -843,6 +867,8 @@ class MiniMaxChat(Base):
 class MistralChat(Base):
+    _FACTORY_NAME = "Mistral"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -896,6 +922,8 @@ class MistralChat(Base):
 class BedrockChat(Base):
+    _FACTORY_NAME = "Bedrock"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -978,6 +1006,8 @@ class BedrockChat(Base):
 class GeminiChat(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -997,6 +1027,7 @@ class GeminiChat(Base):
     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
+
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1019,6 +1050,7 @@ class GeminiChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
+
         gen_conf = self._clean_conf(gen_conf)
         if system:
             self.model._system_instruction = content_types.to_content(system)
@@ -1042,6 +1074,8 @@ class GeminiChat(Base):
 class GroqChat(Base):
+    _FACTORY_NAME = "Groq"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1086,6 +1120,8 @@ class GroqChat(Base):
 ## openrouter
 class OpenRouterChat(Base):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
         if not base_url:
             base_url = "https://openrouter.ai/api/v1"
@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):
 class StepFunChat(Base):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
@@ -1100,6 +1138,8 @@ class StepFunChat(Base):
 class NvidiaChat(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
         if not base_url:
             base_url = "https://integrate.api.nvidia.com/v1"
@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):
 class LmStudioChat(Base):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):
 class OpenAI_APIChat(Base):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):
 class PPIOChat(Base):
+    _FACTORY_NAME = "PPIO"
+
     def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.ppinfra.com/v3/openai"
@@ -1132,6 +1178,8 @@ class PPIOChat(Base):
 class CoHereChat(Base):
+    _FACTORY_NAME = "Cohere"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1207,6 +1255,8 @@ class CoHereChat(Base):
 class LeptonAIChat(Base):
+    _FACTORY_NAME = "LeptonAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         if not base_url:
             base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):
 class TogetherAIChat(Base):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):
 class PerfXCloudChat(Base):
+    _FACTORY_NAME = "PerfXCloud"
+
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://cloud.perfxlab.cn/v1"
@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):
 class UpstageChat(Base):
+    _FACTORY_NAME = "Upstage"
+
     def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
         if not base_url:
             base_url = "https://api.upstage.ai/v1/solar"
@@ -1235,6 +1291,8 @@ class UpstageChat(Base):
 class NovitaAIChat(Base):
+    _FACTORY_NAME = "NovitaAI"
+
     def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.novita.ai/v3/openai"
@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):
 class SILICONFLOWChat(Base):
+    _FACTORY_NAME = "SILICONFLOW"
+
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):
 class YiChat(Base):
+    _FACTORY_NAME = "01.AI"
+
     def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
@@ -1256,6 +1318,8 @@ class YiChat(Base):
 class GiteeChat(Base):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/"
@@ -1263,6 +1327,8 @@ class GiteeChat(Base):
 class ReplicateChat(Base):
+    _FACTORY_NAME = "Replicate"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):
 class HunyuanChat(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):
 class SparkChat(Base):
+    _FACTORY_NAME = "XunFei Spark"
+
     def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1398,6 +1468,8 @@ class SparkChat(Base):
 class BaiduYiyanChat(Base):
+    _FACTORY_NAME = "BaiduYiyan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):
 class AnthropicChat(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://api.anthropic.com/v1/"
@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):
 class GoogleChat(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
             if "role" in item and item["role"] == "assistant":
                 item["role"] = "model"
             if "content" in item:
-                item["parts"] = [{
-                    "text": item.pop("content"),
-                }]
+                item["parts"] = [
+                    {
+                        "text": item.pop("content"),
+                    }
+                ]

         response = self.client.generate_content(hist, generation_config=gen_conf)
         ans = response.text
@@ -1587,8 +1665,10 @@ class GoogleChat(Base):
 class GPUStackChat(Base):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         super().__init__(key, model_name, base_url, **kwargs)
@@ -57,7 +57,7 @@ class Base(ABC):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
@@ -79,7 +79,7 @@ class Base(ABC):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
@@ -87,8 +87,7 @@ class Base(ABC):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens
@@ -117,13 +116,12 @@ class Base(ABC):
             "content": [
                 {
                     "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/jpeg;base64,{b64}"
-                    },
+                    "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                 },
                 {
-                    "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                    "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                    "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                    if self.lang.lower() == "chinese"
+                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                 },
             ],
         }
@@ -136,9 +134,7 @@ class Base(ABC):
             "content": [
                 {
                     "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/jpeg;base64,{b64}"
-                    },
+                    "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                 },
                 {
                     "type": "text",
@@ -156,14 +152,13 @@ class Base(ABC):
                     "url": f"data:image/jpeg;base64,{b64}",
                 },
             },
-            {
-                "type": "text",
-                "text": text
-            },
+            {"type": "text", "text": text},
         ]
 class GptV4(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
@@ -181,7 +176,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt
+            messages=prompt,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -197,9 +192,11 @@ class GptV4(Base):
 class AzureGptV4(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        api_key = json.loads(key).get('api_key', '')
-        api_version = json.loads(key).get('api_version', '2024-02-01')
+        api_key = json.loads(key).get("api_key", "")
+        api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
         self.lang = lang
@@ -212,10 +209,7 @@ class AzureGptV4(Base):
             if "text" in c:
                 c["type"] = "text"
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):
@@ -230,8 +224,11 @@ class AzureGptV4(Base):
 class QWenCV(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
         import dashscope
+
         dashscope.api_key = key
         self.model_name = model_name
         self.lang = lang
@@ -247,12 +244,11 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
-                    {
-                        "image": f"file://{path}"
-                    },
+                    {"image": f"file://{path}"},
                     {
-                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }
@@ -270,11 +266,9 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
-                    {
-                        "image": f"file://{path}"
-                    },
+                    {"image": f"file://{path}"},
                     {
                         "text": prompt if prompt else vision_llm_describe_prompt(),
                     },
                 ],
             }
@@ -290,9 +284,10 @@ class QWenCV(Base):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def describe_with_prompt(self, image, prompt=None):
@@ -303,33 +298,36 @@ class QWenCV(Base):
         vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
         response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def chat(self, system, history, gen_conf, image=""):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

         for his in history:
             if his["role"] == "user":
                 his["content"] = self.chat_prompt(his["content"], image)
-        response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                               temperature=gen_conf.get("temperature", 0.3),
-                                               top_p=gen_conf.get("top_p", 0.7))
+        response = MultiModalConversation.call(
+            model=self.model_name,
+            messages=history,
+            temperature=gen_conf.get("temperature", 0.3),
+            top_p=gen_conf.get("top_p", 0.7),
+        )

         ans = ""
         tk_count = 0
         if response.status_code == HTTPStatus.OK:
-            ans = response.output.choices[0]['message']['content']
+            ans = response.output.choices[0]["message"]["content"]
             if isinstance(ans, list):
                 ans = ans[0]["text"] if ans else ""
             tk_count += response.usage.total_tokens
             if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, tk_count

         return "**ERROR**: " + response.message, tk_count
@@ -338,6 +336,7 @@ class QWenCV(Base):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation

         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -348,24 +347,25 @@ class QWenCV(Base):
         ans = ""
         tk_count = 0
         try:
-            response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                                   temperature=gen_conf.get("temperature", 0.3),
-                                                   top_p=gen_conf.get("top_p", 0.7),
-                                                   stream=True)
+            response = MultiModalConversation.call(
+                model=self.model_name,
+                messages=history,
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True,
+            )
             for resp in response:
                 if resp.status_code == HTTPStatus.OK:
-                    cnt = resp.output.choices[0]['message']['content']
+                    cnt = resp.output.choices[0]["message"]["content"]
                     if isinstance(cnt, list):
                         cnt = cnt[0]["text"] if ans else ""
                     ans += cnt
                     tk_count = resp.usage.total_tokens
                     if resp.output.choices[0].get("finish_reason", "") == "length":
-                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
-                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -373,6 +373,8 @@ class QWenCV(Base):
 class Zhipu4V(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
@@ -394,10 +396,7 @@ class Zhipu4V(Base):
         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=vision_prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def chat(self, system, history, gen_conf, image=""):
@@ -412,7 +411,7 @@ class Zhipu4V(Base):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
@@ -434,7 +433,7 @@ class Zhipu4V(Base):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
@@ -442,8 +441,7 @@ class Zhipu4V(Base):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens
@@ -455,6 +453,8 @@ class Zhipu4V(Base):
 class OllamaCV(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name
@@ -466,7 +466,7 @@ class OllamaCV(Base):
             response = self.client.generate(
                 model=self.model_name,
                 prompt=prompt[0]["content"][1]["text"],
-                images=[image]
+                images=[image],
             )
             ans = response["response"].strip()
             return ans, 128
@@ -507,7 +507,7 @@ class OllamaCV(Base):
                 model=self.model_name,
                 messages=history,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )

             ans = response["message"]["content"].strip()
@@ -538,7 +538,7 @@ class OllamaCV(Base):
                 messages=history,
                 stream=True,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )
             for resp in response:
                 if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):
 class LocalAICV(GptV4):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):
 class XinferenceCV(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
     def describe(self, image):
         b64 = self.image2base64(image)

-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=self.prompt(b64)
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):
 class GeminiCV(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
         from google.generativeai import GenerativeModel, client
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):
     def describe(self, image):
         from PIL.Image import open

-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt, img]
-        res = self.model.generate_content(
-            input
-        )
+        res = self.model.generate_content(input)
         return res.text, res.usage_metadata.total_token_count

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open

         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
         img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):
     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
         try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)))
+            response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
             ans = response.text
             return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):
     def chat_streamly(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -661,9 +668,11 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+            response = self.model.generate_content(
+                history,
+                generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
+                stream=True,
+            )

             for resp in response:
                 if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):
 class OpenRouterCV(GptV4):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(
         self,
         key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):
 class LocalCV(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
@@ -700,6 +713,8 @@ class LocalCV(Base):
 class NvidiaCV(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(
         self,
         key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
                 "content-type": "application/json",
                 "Authorization": f"Bearer {self.key}",
             },
-            json={
-                "messages": self.prompt(b64)
-            },
+            json={"messages": self.prompt(b64)},
         )
         response = response.json()
         return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
         return [
             {
                 "role": "user",
-                "content": (
-                    prompt if prompt else vision_llm_describe_prompt()
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
             }
         ]
@@ -791,6 +801,8 @@ class NvidiaCV(Base):
 class StepFunCV(GptV4):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):
 class LmStudioCV(GptV4):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):
 class OpenAI_APICV(GptV4):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):
 class TogetherAICV(GptV4):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):
 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
+    _FACTORY_NAME = "01.AI"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.lingyiwanwu.com/v1",
+    ):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
         super().__init__(key, model_name, lang, base_url)


 class SILICONFLOWCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
+    _FACTORY_NAME = "SILICONFLOW"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.siliconflow.cn/v1",
+    ):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
         super().__init__(key, model_name, lang, base_url)


 class HunyuanCV(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None):
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
             "Contents": [
                 {
                     "Type": "image_url",
-                    "ImageUrl": {
-                        "Url": f"data:image/jpeg;base64,{b64}"
-                    },
+                    "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
                 },
                 {
                     "Type": "text",
-                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                    "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                    if self.lang.lower() == "chinese"
+                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                 },
             ],
         }
@@ -910,6 +945,8 @@ class HunyuanCV(Base):
 class AnthropicCV(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url=None):
         import anthropic
@@ -933,38 +970,29 @@ class AnthropicCV(Base):
                         "data": b64,
                     },
                 },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
+                {"type": "text", "text": prompt},
             ],
         }
     ]

     def describe(self, image):
         b64 = self.image2base64(image)
-        prompt = self.prompt(b64,
-                             "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                             )
-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        prompt = self.prompt(
+            b64,
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+        )
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def chat(self, system, history, gen_conf):
         if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
             ).to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
                 response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
                 **gen_conf,
             )
             for res in response:
-                if res.type == 'content_block_delta':
+                if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
                         if ans.find("<think>") < 0:
                             ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):
         yield total_tokens


 class GPUStackCV(GptV4):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -1041,11 +1068,13 @@ class GPUStackCV(GptV4):
 class GoogleCV(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
         import base64
+
         from google.oauth2 import service_account

         key = json.loads(key)
         access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
         project_id = key.get("google_project_id", "")
@@ -1079,9 +1108,12 @@ class GoogleCV(Base):
             self.client = glm.GenerativeModel(model_name=self.model_name)

     def describe(self, image):
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
         if "claude" in self.model_name:
             b64 = self.image2base64(image)
             vision_prompt = [
@@ -1096,28 +1128,22 @@ class GoogleCV(Base):
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt
-                    }
+                    {"type": "text", "text": prompt},
                 ],
             }
         ]
             response = self.client.messages.create(
                 model=self.model_name,
                 max_tokens=8192,
-                messages=vision_prompt
+                messages=vision_prompt,
             )
             return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
         else:
             import vertexai.generative_models as glm

             b64 = self.image2base64(image)
             # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
             input = [prompt, image_part]
             res = self.client.generate_content(input)
             return res.text, res.usage_metadata.total_token_count
@@ -1137,29 +1163,19 @@ class GoogleCV(Base):
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt if prompt else vision_llm_describe_prompt()
-                    }
+                    {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
                 ],
             }
         ]
-            response = self.client.messages.create(
-                model=self.model_name,
-                max_tokens=8192,
-                messages=vision_prompt
-            )
+            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
             return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
         else:
             import vertexai.generative_models as glm

             b64 = self.image2base64(image)
             vision_prompt = prompt if prompt else vision_llm_describe_prompt()
             # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
             input = [vision_prompt, image_part]
             res = self.client.generate_content(input)
             return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
                             "data": image,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": his["content"]
-                    }
+                    {"type": "text", "text": his["content"]},
                 ]

-            response = self.client.messages.create(
-                model=self.model_name,
-                max_tokens=8192,
-                messages=history,
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
-            )
+            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
             return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
         except Exception as e:
             return "**ERROR**: " + str(e), 0
         else:
             import vertexai.generative_models as glm
             from transformers import GenerationConfig

             if system:
                 history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
             try:
@@ -1210,20 +1218,15 @@ class GoogleCV(Base):
                 if his["role"] == "user":
                     his["parts"] = [his["content"]]
                     his.pop("content")
             # Create proper image part for Gemini
             img_bytes = base64.b64decode(image)
-            image_part = glm.Part.from_data(
-                data=img_bytes,
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
             history[-1]["parts"].append(image_part)

-            response = self.client.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)))
+            response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
             ans = response.text
             return ans, response.usage_metadata.total_token_count
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -13,28 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
+import logging
+import os
+import re
 import threading
+from abc import ABC
 from urllib.parse import urljoin
+
+import dashscope
+import google.generativeai as genai
+import numpy as np
 import requests
 from huggingface_hub import snapshot_download
+from zhipuai import ZhipuAI
-import os
-from abc import ABC
 from ollama import Client
-import dashscope
 from openai import OpenAI
-import numpy as np
-import asyncio
-from zhipuai import ZhipuAI
+
 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
 from rag.utils import num_tokens_from_string, truncate
-import google.generativeai as genai
-import json


 class Base(ABC):
@@ -60,7 +59,8 @@ class Base(ABC):
 class DefaultEmbedding(Base):
-    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+    _FACTORY_NAME = "BAAI"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     _model = None
     _model_name = ""
     _model_lock = threading.Lock()
@@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
         """
         if not settings.LIGHTEN:
             with DefaultEmbedding._model_lock:
-                from FlagEmbedding import FlagModel
                 import torch
+                from FlagEmbedding import FlagModel
+
                 if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
-                        DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                                            use_fp16=torch.cuda.is_available())
+                        DefaultEmbedding._model = FlagModel(
+                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                            use_fp16=torch.cuda.is_available(),
+                        )
                         DefaultEmbedding._model_name = model_name
                     except Exception:
-                        model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
-                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                      local_dir_use_symlinks=False)
-                        DefaultEmbedding._model = FlagModel(model_dir,
-                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                                            use_fp16=torch.cuda.is_available())
+                        model_dir = snapshot_download(
+                            repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
+                        )
+                        DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
         self._model_name = DefaultEmbedding._model_name
@@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
             token_count += num_tokens_from_string(t)
         ress = []
         for i in range(0, len(texts), batch_size):
-            ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
+            ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
         return np.array(ress), token_count

     def encode_queries(self, text: str):
@@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
 class OpenAIEmbed(Base):
-    def __init__(self, key, model_name="text-embedding-ada-002",
-                 base_url="https://api.openai.com/v1"):
+    _FACTORY_NAME = "OpenAI"
+
+    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
         ress = []
         total_tokens = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings.create(input=texts[i:i + batch_size],
-                                                model=self.model_name)
+            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 total_tokens += self.total_token_count(res)
@@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
         return np.array(ress), total_tokens

     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[truncate(text, 8191)],
-                                            model=self.model_name)
+        res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
         return np.array(res.data[0].embedding), self.total_token_count(res)


 class LocalAIEmbed(Base):
+    _FACTORY_NAME = "LocalAI"
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("Local embedding model url cannot be None") | |||
| @@ -155,7 +157,7 @@ class LocalAIEmbed(Base): | |||
| batch_size = 16 | |||
| ress = [] | |||
| for i in range(0, len(texts), batch_size): | |||
| res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) | |||
| res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) | |||
| try: | |||
| ress.extend([d.embedding for d in res.data]) | |||
| except Exception as _e: | |||
| @@ -169,41 +171,42 @@ class LocalAIEmbed(Base): | |||
| class AzureEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "Azure-OpenAI" | |||
| def __init__(self, key, model_name, **kwargs): | |||
| from openai.lib.azure import AzureOpenAI | |||
| api_key = json.loads(key).get('api_key', '') | |||
| api_version = json.loads(key).get('api_version', '2024-02-01') | |||
| api_key = json.loads(key).get("api_key", "") | |||
| api_version = json.loads(key).get("api_version", "2024-02-01") | |||
| self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) | |||
| self.model_name = model_name | |||
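For reference, `AzureEmbed` unpacks its credentials from a JSON-encoded `key`. A hedged usage sketch with placeholder values (the deployment name and endpoint are assumptions, not values from this PR):

```python
# Illustrative values only; Azure's model_name is typically a deployment name.
import json

key = json.dumps({"api_key": "<AZURE_OPENAI_API_KEY>", "api_version": "2024-02-01"})
embedder = AzureEmbed(key, "my-embedding-deployment", base_url="https://<resource>.openai.azure.com")
```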
| class BaiChuanEmbed(OpenAIEmbed): | |||
| def __init__(self, key, | |||
| model_name='Baichuan-Text-Embedding', | |||
| base_url='https://api.baichuan-ai.com/v1'): | |||
| _FACTORY_NAME = "BaiChuan" | |||
| def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.baichuan-ai.com/v1" | |||
| super().__init__(key, model_name, base_url) | |||
| class QWenEmbed(Base): | |||
| _FACTORY_NAME = "Tongyi-Qianwen" | |||
| def __init__(self, key, model_name="text_embedding_v2", **kwargs): | |||
| self.key = key | |||
| self.model_name = model_name | |||
| def encode(self, texts: list): | |||
| import dashscope | |||
| batch_size = 4 | |||
| res = [] | |||
| token_count = 0 | |||
| texts = [truncate(t, 2048) for t in texts] | |||
| for i in range(0, len(texts), batch_size): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=texts[i:i + batch_size], | |||
| api_key=self.key, | |||
| text_type="document" | |||
| ) | |||
| resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") | |||
| try: | |||
| embds = [[] for _ in range(len(resp["output"]["embeddings"]))] | |||
| for e in resp["output"]["embeddings"]: | |||
| @@ -216,20 +219,16 @@ class QWenEmbed(Base): | |||
| return np.array(res), token_count | |||
| def encode_queries(self, text): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=text[:2048], | |||
| api_key=self.key, | |||
| text_type="query" | |||
| ) | |||
| resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query") | |||
| try: | |||
| return np.array(resp["output"]["embeddings"][0] | |||
| ["embedding"]), self.total_token_count(resp) | |||
| return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp) | |||
| except Exception as _e: | |||
| log_exception(_e, resp) | |||
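The index bookkeeping in `encode` exists because dashscope returns one record per input, each tagged with a `text_index`. An illustrative response shape (field names taken from the code above, values invented):

```python
resp = {
    "output": {
        "embeddings": [
            {"text_index": 1, "embedding": [0.12, -0.03]},
            {"text_index": 0, "embedding": [0.98, 0.44]},
        ]
    },
    "usage": {"total_tokens": 17},
}
# encode() slots each vector back by text_index, so output order matches
# input order even if the response order differs.
```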
| class ZhipuEmbed(Base): | |||
| _FACTORY_NAME = "ZHIPU-AI" | |||
| def __init__(self, key, model_name="embedding-2", **kwargs): | |||
| self.client = ZhipuAI(api_key=key) | |||
| self.model_name = model_name | |||
| @@ -246,8 +245,7 @@ class ZhipuEmbed(Base): | |||
| texts = [truncate(t, MAX_LEN) for t in texts] | |||
| for txt in texts: | |||
| res = self.client.embeddings.create(input=txt, | |||
| model=self.model_name) | |||
| res = self.client.embeddings.create(input=txt, model=self.model_name) | |||
| try: | |||
| arr.append(res.data[0].embedding) | |||
| tks_num += self.total_token_count(res) | |||
| @@ -256,8 +254,7 @@ class ZhipuEmbed(Base): | |||
| return np.array(arr), tks_num | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=text, | |||
| model=self.model_name) | |||
| res = self.client.embeddings.create(input=text, model=self.model_name) | |||
| try: | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| except Exception as _e: | |||
| @@ -265,18 +262,17 @@ class ZhipuEmbed(Base): | |||
| class OllamaEmbed(Base): | |||
| _FACTORY_NAME = "Ollama" | |||
| def __init__(self, key, model_name, **kwargs): | |||
| self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \ | |||
| Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"}) | |||
| self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"}) | |||
| self.model_name = model_name | |||
| def encode(self, texts: list): | |||
| arr = [] | |||
| tks_num = 0 | |||
| for txt in texts: | |||
| res = self.client.embeddings(prompt=txt, | |||
| model=self.model_name, | |||
| options={"use_mmap": True}) | |||
| res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}) | |||
| try: | |||
| arr.append(res["embedding"]) | |||
| except Exception as _e: | |||
| @@ -285,9 +281,7 @@ class OllamaEmbed(Base): | |||
| return np.array(arr), tks_num | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings(prompt=text, | |||
| model=self.model_name, | |||
| options={"use_mmap": True}) | |||
| res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}) | |||
| try: | |||
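| # Note: Ollama's embeddings endpoint reports no token usage, so the 128 below is only a placeholder count. | |||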
| return np.array(res["embedding"]), 128 | |||
| except Exception as _e: | |||
| @@ -295,27 +289,28 @@ class OllamaEmbed(Base): | |||
| class FastEmbed(DefaultEmbedding): | |||
| _FACTORY_NAME = "FastEmbed" | |||
| def __init__( | |||
| self, | |||
| key: str | None = None, | |||
| model_name: str = "BAAI/bge-small-en-v1.5", | |||
| cache_dir: str | None = None, | |||
| threads: int | None = None, | |||
| **kwargs, | |||
| self, | |||
| key: str | None = None, | |||
| model_name: str = "BAAI/bge-small-en-v1.5", | |||
| cache_dir: str | None = None, | |||
| threads: int | None = None, | |||
| **kwargs, | |||
| ): | |||
| if not settings.LIGHTEN: | |||
| with FastEmbed._model_lock: | |||
| from fastembed import TextEmbedding | |||
| if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: | |||
| try: | |||
| DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | |||
| DefaultEmbedding._model_name = model_name | |||
| except Exception: | |||
| cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5", | |||
| local_dir=os.path.join(get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| cache_dir = snapshot_download( | |||
| repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False | |||
| ) | |||
| DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | |||
| self._model = DefaultEmbedding._model | |||
| self._model_name = model_name | |||
| @@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding): | |||
| class XinferenceEmbed(Base): | |||
| _FACTORY_NAME = "Xinference" | |||
| def __init__(self, key, model_name="", base_url=""): | |||
| base_url = urljoin(base_url, "v1") | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| @@ -350,7 +347,7 @@ class XinferenceEmbed(Base): | |||
| ress = [] | |||
| total_tokens = 0 | |||
| for i in range(0, len(texts), batch_size): | |||
| res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) | |||
| res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) | |||
| try: | |||
| ress.extend([d.embedding for d in res.data]) | |||
| total_tokens += self.total_token_count(res) | |||
| @@ -359,8 +356,7 @@ class XinferenceEmbed(Base): | |||
| return np.array(ress), total_tokens | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=[text], | |||
| model=self.model_name) | |||
| res = self.client.embeddings.create(input=[text], model=self.model_name) | |||
| try: | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| except Exception as _e: | |||
| @@ -368,20 +364,18 @@ class XinferenceEmbed(Base): | |||
| class YoudaoEmbed(Base): | |||
| _FACTORY_NAME = "Youdao" | |||
| _client = None | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | |||
| if not settings.LIGHTEN and not YoudaoEmbed._client: | |||
| from BCEmbedding import EmbeddingModel as qanthing | |||
| try: | |||
| logging.info("LOADING BCE...") | |||
| YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join( | |||
| get_home_cache_dir(), | |||
| "bce-embedding-base_v1")) | |||
| YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1")) | |||
| except Exception: | |||
| YoudaoEmbed._client = qanthing( | |||
| model_name_or_path=model_name.replace( | |||
| "maidalun1020", "InfiniFlow")) | |||
| YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow")) | |||
| def encode(self, texts: list): | |||
| batch_size = 10 | |||
| @@ -390,7 +384,7 @@ class YoudaoEmbed(Base): | |||
| for t in texts: | |||
| token_count += num_tokens_from_string(t) | |||
| for i in range(0, len(texts), batch_size): | |||
| embds = YoudaoEmbed._client.encode(texts[i:i + batch_size]) | |||
| embds = YoudaoEmbed._client.encode(texts[i : i + batch_size]) | |||
| res.extend(embds) | |||
| return np.array(res), token_count | |||
| @@ -400,14 +394,11 @@ class YoudaoEmbed(Base): | |||
| class JinaEmbed(Base): | |||
| def __init__(self, key, model_name="jina-embeddings-v3", | |||
| base_url="https://api.jina.ai/v1/embeddings"): | |||
| _FACTORY_NAME = "Jina" | |||
| def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"): | |||
| self.base_url = "https://api.jina.ai/v1/embeddings" | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} | |||
| self.model_name = model_name | |||
| def encode(self, texts: list): | |||
| @@ -416,11 +407,7 @@ class JinaEmbed(Base): | |||
| ress = [] | |||
| token_count = 0 | |||
| for i in range(0, len(texts), batch_size): | |||
| data = { | |||
| "model": self.model_name, | |||
| "input": texts[i:i + batch_size], | |||
| 'encoding_type': 'float' | |||
| } | |||
| data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"} | |||
| response = requests.post(self.base_url, headers=self.headers, json=data) | |||
| try: | |||
| res = response.json() | |||
| @@ -435,50 +422,12 @@ class JinaEmbed(Base): | |||
| return np.array(embds[0]), cnt | |||
| class InfinityEmbed(Base): | |||
| _model = None | |||
| def __init__( | |||
| self, | |||
| model_names: list[str] = ("BAAI/bge-small-en-v1.5",), | |||
| engine_kwargs: dict = {}, | |||
| key = None, | |||
| ): | |||
| from infinity_emb import EngineArgs | |||
| from infinity_emb.engine import AsyncEngineArray | |||
| self._default_model = model_names[0] | |||
| self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names]) | |||
| async def _embed(self, sentences: list[str], model_name: str = ""): | |||
| if not model_name: | |||
| model_name = self._default_model | |||
| engine = self.engine_array[model_name] | |||
| was_already_running = engine.is_running | |||
| if not was_already_running: | |||
| await engine.astart() | |||
| embeddings, usage = await engine.embed(sentences=sentences) | |||
| if not was_already_running: | |||
| await engine.astop() | |||
| return embeddings, usage | |||
| def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]: | |||
| # Using the internal tokenizer to encode the texts and get the total | |||
| # number of tokens | |||
| embeddings, usage = asyncio.run(self._embed(texts, model_name)) | |||
| return np.array(embeddings), usage | |||
| def encode_queries(self, text: str) -> tuple[np.ndarray, int]: | |||
| # Using the internal tokenizer to encode the texts and get the total | |||
| # number of tokens | |||
| return self.encode([text]) | |||
| class MistralEmbed(Base): | |||
| def __init__(self, key, model_name="mistral-embed", | |||
| base_url=None): | |||
| _FACTORY_NAME = "Mistral" | |||
| def __init__(self, key, model_name="mistral-embed", base_url=None): | |||
| from mistralai.client import MistralClient | |||
| self.client = MistralClient(api_key=key) | |||
| self.model_name = model_name | |||
| @@ -488,8 +437,7 @@ class MistralEmbed(Base): | |||
| ress = [] | |||
| token_count = 0 | |||
| for i in range(0, len(texts), batch_size): | |||
| res = self.client.embeddings(input=texts[i:i + batch_size], | |||
| model=self.model_name) | |||
| res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name) | |||
| try: | |||
| ress.extend([d.embedding for d in res.data]) | |||
| token_count += self.total_token_count(res) | |||
| @@ -498,8 +446,7 @@ class MistralEmbed(Base): | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings(input=[truncate(text, 8196)], | |||
| model=self.model_name) | |||
| res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name) | |||
| try: | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| except Exception as _e: | |||
| @@ -507,30 +454,31 @@ class MistralEmbed(Base): | |||
| class BedrockEmbed(Base): | |||
| def __init__(self, key, model_name, | |||
| **kwargs): | |||
| _FACTORY_NAME = "Bedrock" | |||
| def __init__(self, key, model_name, **kwargs): | |||
| import boto3 | |||
| self.bedrock_ak = json.loads(key).get('bedrock_ak', '') | |||
| self.bedrock_sk = json.loads(key).get('bedrock_sk', '') | |||
| self.bedrock_region = json.loads(key).get('bedrock_region', '') | |||
| self.bedrock_ak = json.loads(key).get("bedrock_ak", "") | |||
| self.bedrock_sk = json.loads(key).get("bedrock_sk", "") | |||
| self.bedrock_region = json.loads(key).get("bedrock_region", "") | |||
| self.model_name = model_name | |||
| if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': | |||
| if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "": | |||
| # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) | |||
| self.client = boto3.client('bedrock-runtime') | |||
| self.client = boto3.client("bedrock-runtime") | |||
| else: | |||
| self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, | |||
| aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) | |||
| self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) | |||
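As with Azure, the Bedrock credentials travel inside the JSON `key`. A hedged sketch with placeholder values (the model id is only an example):

```python
# Illustrative values only; empty fields fall back to boto3's default
# credential chain (AWS_PROFILE, instance role, etc.), per the code above.
import json

key = json.dumps({
    "bedrock_ak": "<AWS_ACCESS_KEY_ID>",
    "bedrock_sk": "<AWS_SECRET_ACCESS_KEY>",
    "bedrock_region": "us-east-1",
})
embedder = BedrockEmbed(key, "amazon.titan-embed-text-v2:0")
```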
| def encode(self, texts: list): | |||
| texts = [truncate(t, 8196) for t in texts] | |||
| embeddings = [] | |||
| token_count = 0 | |||
| for text in texts: | |||
| if self.model_name.split('.')[0] == 'amazon': | |||
| if self.model_name.split(".")[0] == "amazon": | |||
| body = {"inputText": text} | |||
| elif self.model_name.split('.')[0] == 'cohere': | |||
| body = {"texts": [text], "input_type": 'search_document'} | |||
| elif self.model_name.split(".")[0] == "cohere": | |||
| body = {"texts": [text], "input_type": "search_document"} | |||
| response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) | |||
| try: | |||
| @@ -545,10 +493,10 @@ class BedrockEmbed(Base): | |||
| def encode_queries(self, text): | |||
| embeddings = [] | |||
| token_count = num_tokens_from_string(text) | |||
| if self.model_name.split('.')[0] == 'amazon': | |||
| if self.model_name.split(".")[0] == "amazon": | |||
| body = {"inputText": truncate(text, 8196)} | |||
| elif self.model_name.split('.')[0] == 'cohere': | |||
| body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'} | |||
| elif self.model_name.split(".")[0] == "cohere": | |||
| body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} | |||
| response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) | |||
| try: | |||
| @@ -561,11 +509,12 @@ class BedrockEmbed(Base): | |||
| class GeminiEmbed(Base): | |||
| def __init__(self, key, model_name='models/text-embedding-004', | |||
| **kwargs): | |||
| _FACTORY_NAME = "Gemini" | |||
| def __init__(self, key, model_name="models/text-embedding-004", **kwargs): | |||
| self.key = key | |||
| self.model_name = 'models/' + model_name | |||
| self.model_name = "models/" + model_name | |||
| def encode(self, texts: list): | |||
| texts = [truncate(t, 2048) for t in texts] | |||
| token_count = sum(num_tokens_from_string(text) for text in texts) | |||
| @@ -573,35 +522,27 @@ class GeminiEmbed(Base): | |||
| batch_size = 16 | |||
| ress = [] | |||
| for i in range(0, len(texts), batch_size): | |||
| result = genai.embed_content( | |||
| model=self.model_name, | |||
| content=texts[i: i + batch_size], | |||
| task_type="retrieval_document", | |||
| title="Embedding of single string") | |||
| result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string") | |||
| try: | |||
| ress.extend(result['embedding']) | |||
| ress.extend(result["embedding"]) | |||
| except Exception as _e: | |||
| log_exception(_e, result) | |||
| return np.array(ress),token_count | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| genai.configure(api_key=self.key) | |||
| result = genai.embed_content( | |||
| model=self.model_name, | |||
| content=truncate(text,2048), | |||
| task_type="retrieval_document", | |||
| title="Embedding of single string") | |||
| result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string") | |||
| token_count = num_tokens_from_string(text) | |||
| try: | |||
| return np.array(result['embedding']), token_count | |||
| return np.array(result["embedding"]), token_count | |||
| except Exception as _e: | |||
| log_exception(_e, result) | |||
| class NvidiaEmbed(Base): | |||
| def __init__( | |||
| self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings" | |||
| ): | |||
| _FACTORY_NAME = "NVIDIA" | |||
| def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"): | |||
| if not base_url: | |||
| base_url = "https://integrate.api.nvidia.com/v1/embeddings" | |||
| self.api_key = key | |||
| @@ -645,6 +586,8 @@ class NvidiaEmbed(Base): | |||
| class LmStudioEmbed(LocalAIEmbed): | |||
| _FACTORY_NAME = "LM-Studio" | |||
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("Local llm url cannot be None") | |||
| @@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed): | |||
| class OpenAI_APIEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] | |||
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("url cannot be None") | |||
| @@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed): | |||
| class CoHereEmbed(Base): | |||
| _FACTORY_NAME = "Cohere" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| from cohere import Client | |||
| @@ -701,6 +648,8 @@ class CoHereEmbed(Base): | |||
| class TogetherAIEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "TogetherAI" | |||
| def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.together.xyz/v1" | |||
| @@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed): | |||
| class PerfXCloudEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "PerfXCloud" | |||
| def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): | |||
| if not base_url: | |||
| base_url = "https://cloud.perfxlab.cn/v1" | |||
| @@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed): | |||
| class UpstageEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "Upstage" | |||
| def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"): | |||
| if not base_url: | |||
| base_url = "https://api.upstage.ai/v1/solar" | |||
| @@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed): | |||
| class SILICONFLOWEmbed(Base): | |||
| _FACTORY_NAME = "SILICONFLOW" | |||
| def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"): | |||
| if not base_url: | |||
| base_url = "https://api.siliconflow.cn/v1/embeddings" | |||
| @@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base): | |||
| class ReplicateEmbed(Base): | |||
| _FACTORY_NAME = "Replicate" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| from replicate.client import Client | |||
| @@ -790,6 +747,8 @@ class ReplicateEmbed(Base): | |||
| class BaiduYiyanEmbed(Base): | |||
| _FACTORY_NAME = "BaiduYiyan" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| import qianfan | |||
| @@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base): | |||
| class VoyageEmbed(Base): | |||
| _FACTORY_NAME = "Voyage AI" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| import voyageai | |||
| @@ -832,9 +793,7 @@ class VoyageEmbed(Base): | |||
| ress = [] | |||
| token_count = 0 | |||
| for i in range(0, len(texts), batch_size): | |||
| res = self.client.embed( | |||
| texts=texts[i : i + batch_size], model=self.model_name, input_type="document" | |||
| ) | |||
| res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document") | |||
| try: | |||
| ress.extend(res.embeddings) | |||
| token_count += res.total_tokens | |||
| @@ -843,9 +802,7 @@ class VoyageEmbed(Base): | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| res = self.client.embed( | |||
| texts=text, model=self.model_name, input_type="query" | |||
| ) | |||
| res = self.client.embed(texts=text, model=self.model_name, input_type="query") | |||
| try: | |||
| return np.array(res.embeddings)[0], res.total_tokens | |||
| except Exception as _e: | |||
| @@ -853,6 +810,8 @@ class VoyageEmbed(Base): | |||
| class HuggingFaceEmbed(Base): | |||
| _FACTORY_NAME = "HuggingFace" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| if not model_name: | |||
| raise ValueError("Model name cannot be None") | |||
| @@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base): | |||
| def encode(self, texts: list): | |||
| embeddings = [] | |||
| for text in texts: | |||
| response = requests.post( | |||
| f"{self.base_url}/embed", | |||
| json={"inputs": text}, | |||
| headers={'Content-Type': 'application/json'} | |||
| ) | |||
| response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}) | |||
| if response.status_code == 200: | |||
| embedding = response.json() | |||
| embeddings.append(embedding[0]) | |||
| @@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base): | |||
| return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) | |||
| def encode_queries(self, text): | |||
| response = requests.post( | |||
| f"{self.base_url}/embed", | |||
| json={"inputs": text}, | |||
| headers={'Content-Type': 'application/json'} | |||
| ) | |||
| response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}) | |||
| if response.status_code == 200: | |||
| embedding = response.json() | |||
| return np.array(embedding[0]), num_tokens_from_string(text) | |||
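Both methods target a Text-Embeddings-Inference-style `/embed` route that returns a list of vectors, one per input. A hedged request/response sketch (the local URL is an assumption, not part of this PR):

```python
# Assumes a TEI-compatible server at localhost:8080.
import requests

resp = requests.post(
    "http://localhost:8080/embed",
    json={"inputs": "hello world"},
    headers={"Content-Type": "application/json"},
)
vector = resp.json()[0]  # response is [[floats]], hence embedding[0] above
```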
| @@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base): | |||
| class VolcEngineEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "VolcEngine" | |||
| def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): | |||
| if not base_url: | |||
| base_url = "https://ark.cn-beijing.volces.com/api/v3" | |||
| ark_api_key = json.loads(key).get('ark_api_key', '') | |||
| model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') | |||
| super().__init__(ark_api_key,model_name,base_url) | |||
| ark_api_key = json.loads(key).get("ark_api_key", "") | |||
| model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") | |||
| super().__init__(ark_api_key, model_name, base_url) | |||
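Note the quirk here: the positional `model_name` is overwritten by the endpoint id carried in the JSON `key` (either an `ep_id` or an `endpoint_id` field is honored). A hedged sketch with placeholder values:

```python
# Illustrative values only; the endpoint id in the key, not the positional
# argument, selects the model.
import json

key = json.dumps({"ark_api_key": "<ARK_API_KEY>", "endpoint_id": "ep-2025xxxx-yyyy"})
embedder = VolcEngineEmbed(key, model_name="")  # model_name is taken from the key
```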
| class GPUStackEmbed(OpenAIEmbed): | |||
| _FACTORY_NAME = "GPUStack" | |||
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("url cannot be None") | |||
| @@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed): | |||
| class NovitaEmbed(SILICONFLOWEmbed): | |||
| _FACTORY_NAME = "NovitaAI" | |||
| def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"): | |||
| if not base_url: | |||
| base_url = "https://api.novita.ai/v3/openai/embeddings" | |||
| @@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed): | |||
| class GiteeEmbed(SILICONFLOWEmbed): | |||
| _FACTORY_NAME = "GiteeAI" | |||
| def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"): | |||
| if not base_url: | |||
| base_url = "https://ai.gitee.com/v1/embeddings" | |||
| super().__init__(key, model_name, base_url) | |||
| @@ -13,24 +13,24 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| import os | |||
| import re | |||
| import threading | |||
| from abc import ABC | |||
| from collections.abc import Iterable | |||
| from urllib.parse import urljoin | |||
| import requests | |||
| import httpx | |||
| from huggingface_hub import snapshot_download | |||
| import os | |||
| from abc import ABC | |||
| import numpy as np | |||
| import requests | |||
| from huggingface_hub import snapshot_download | |||
| from yarl import URL | |||
| from api import settings | |||
| from api.utils.file_utils import get_home_cache_dir | |||
| from api.utils.log_utils import log_exception | |||
| from rag.utils import num_tokens_from_string, truncate | |||
| import json | |||
| def sigmoid(x): | |||
| @@ -57,6 +57,7 @@ class Base(ABC): | |||
| class DefaultRerank(Base): | |||
| _FACTORY_NAME = "BAAI" | |||
| _model = None | |||
| _model_lock = threading.Lock() | |||
| @@ -75,17 +76,13 @@ class DefaultRerank(Base): | |||
| if not settings.LIGHTEN and not DefaultRerank._model: | |||
| import torch | |||
| from FlagEmbedding import FlagReranker | |||
| with DefaultRerank._model_lock: | |||
| if not DefaultRerank._model: | |||
| try: | |||
| DefaultRerank._model = FlagReranker( | |||
| os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), | |||
| use_fp16=torch.cuda.is_available()) | |||
| DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available()) | |||
| except Exception: | |||
| model_dir = snapshot_download(repo_id=model_name, | |||
| local_dir=os.path.join(get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False) | |||
| DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) | |||
| self._model = DefaultRerank._model | |||
| self._dynamic_batch_size = 8 | |||
| @@ -94,6 +91,7 @@ class DefaultRerank(Base): | |||
| def torch_empty_cache(self): | |||
| try: | |||
| import torch | |||
| torch.cuda.empty_cache() | |||
| except Exception as e: | |||
| print(f"Error emptying cache: {e}") | |||
| @@ -112,7 +110,7 @@ class DefaultRerank(Base): | |||
| while retry_count < max_retries: | |||
| try: | |||
| # call subclass implemented batch processing calculation | |||
| batch_scores = self._compute_batch_scores(pairs[i:i + current_batch]) | |||
| batch_scores = self._compute_batch_scores(pairs[i : i + current_batch]) | |||
| res.extend(batch_scores) | |||
| i += current_batch | |||
| self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8) | |||
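The loop above implements OOM-adaptive batching: shrink the batch when a computation fails, then grow it back (capped at 8) after successes. A standalone sketch of the idea, simplified from the code above:

```python
def score_all(pairs, compute_batch, batch_size=8, min_batch=1):
    """Adaptive batching: shrink on RuntimeError (presumed OOM), regrow on success."""
    scores, i = [], 0
    while i < len(pairs):
        try:
            scores.extend(compute_batch(pairs[i : i + batch_size]))
            i += batch_size
            batch_size = min(batch_size * 2, 8)  # recover toward the cap
        except RuntimeError:  # torch raises CUDA OOM as a RuntimeError subclass
            if batch_size <= min_batch:
                raise
            batch_size = max(batch_size // 2, min_batch)
    return scores
```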
| @@ -152,23 +150,16 @@ class DefaultRerank(Base): | |||
| class JinaRerank(Base): | |||
| def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", | |||
| base_url="https://api.jina.ai/v1/rerank"): | |||
| _FACTORY_NAME = "Jina" | |||
| def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"): | |||
| self.base_url = "https://api.jina.ai/v1/rerank" | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} | |||
| self.model_name = model_name | |||
| def similarity(self, query: str, texts: list): | |||
| texts = [truncate(t, 8196) for t in texts] | |||
| data = { | |||
| "model": self.model_name, | |||
| "query": query, | |||
| "documents": texts, | |||
| "top_n": len(texts) | |||
| } | |||
| data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)} | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| try: | |||
| @@ -180,22 +171,20 @@ class JinaRerank(Base): | |||
| class YoudaoRerank(DefaultRerank): | |||
| _FACTORY_NAME = "Youdao" | |||
| _model = None | |||
| _model_lock = threading.Lock() | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): | |||
| if not settings.LIGHTEN and not YoudaoRerank._model: | |||
| from BCEmbedding import RerankerModel | |||
| with YoudaoRerank._model_lock: | |||
| if not YoudaoRerank._model: | |||
| try: | |||
| YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( | |||
| get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z0-9]+/", "", model_name))) | |||
| YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name))) | |||
| except Exception: | |||
| YoudaoRerank._model = RerankerModel( | |||
| model_name_or_path=model_name.replace( | |||
| "maidalun1020", "InfiniFlow")) | |||
| YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow")) | |||
| self._model = YoudaoRerank._model | |||
| self._dynamic_batch_size = 8 | |||
| @@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank): | |||
| class XInferenceRerank(Base): | |||
| _FACTORY_NAME = "Xinference" | |||
| def __init__(self, key="x", model_name="", base_url=""): | |||
| if base_url.find("/v1") == -1: | |||
| base_url = urljoin(base_url, "/v1/rerank") | |||
| @@ -219,10 +210,7 @@ class XInferenceRerank(Base): | |||
| base_url = urljoin(base_url, "/v1/rerank") | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "accept": "application/json" | |||
| } | |||
| self.headers = {"Content-Type": "application/json", "accept": "application/json"} | |||
| if key and key != "x": | |||
| self.headers["Authorization"] = f"Bearer {key}" | |||
| @@ -233,13 +221,7 @@ class XInferenceRerank(Base): | |||
| token_count = 0 | |||
| for _, t in pairs: | |||
| token_count += num_tokens_from_string(t) | |||
| data = { | |||
| "model": self.model_name, | |||
| "query": query, | |||
| "return_documents": "true", | |||
| "return_len": "true", | |||
| "documents": texts | |||
| } | |||
| data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts} | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| try: | |||
| @@ -251,15 +233,14 @@ class XInferenceRerank(Base): | |||
| class LocalAIRerank(Base): | |||
| _FACTORY_NAME = "LocalAI" | |||
| def __init__(self, key, model_name, base_url): | |||
| if base_url.find("/rerank") == -1: | |||
| self.base_url = urljoin(base_url, "/rerank") | |||
| else: | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} | |||
| self.model_name = model_name.split("___")[0] | |||
| def similarity(self, query: str, texts: list): | |||
| @@ -296,16 +277,15 @@ class LocalAIRerank(Base): | |||
| class NvidiaRerank(Base): | |||
| def __init__( | |||
| self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/" | |||
| ): | |||
| _FACTORY_NAME = "NVIDIA" | |||
| def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"): | |||
| if not base_url: | |||
| base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/" | |||
| self.model_name = model_name | |||
| if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3": | |||
| self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking" | |||
| ) | |||
| self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking") | |||
| if self.model_name == "nvidia/rerank-qa-mistral-4b": | |||
| self.base_url = urljoin(base_url, "reranking") | |||
| @@ -318,9 +298,7 @@ class NvidiaRerank(Base): | |||
| } | |||
| def similarity(self, query: str, texts: list): | |||
| token_count = num_tokens_from_string(query) + sum( | |||
| [num_tokens_from_string(t) for t in texts] | |||
| ) | |||
| token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts]) | |||
| data = { | |||
| "model": self.model_name, | |||
| "query": {"text": query}, | |||
| @@ -339,6 +317,8 @@ class NvidiaRerank(Base): | |||
| class LmStudioRerank(Base): | |||
| _FACTORY_NAME = "LM-Studio" | |||
| def __init__(self, key, model_name, base_url): | |||
| pass | |||
| @@ -347,15 +327,14 @@ class LmStudioRerank(Base): | |||
| class OpenAI_APIRerank(Base): | |||
| _FACTORY_NAME = "OpenAI-API-Compatible" | |||
| def __init__(self, key, model_name, base_url): | |||
| if base_url.find("/rerank") == -1: | |||
| self.base_url = urljoin(base_url, "/rerank") | |||
| else: | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} | |||
| self.model_name = model_name.split("___")[0] | |||
| def similarity(self, query: str, texts: list): | |||
| @@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base): | |||
| class CoHereRerank(Base): | |||
| _FACTORY_NAME = ["Cohere", "VLLM"] | |||
| def __init__(self, key, model_name, base_url=None): | |||
| from cohere import Client | |||
| @@ -399,9 +380,7 @@ class CoHereRerank(Base): | |||
| self.model_name = model_name.split("___")[0] | |||
| def similarity(self, query: str, texts: list): | |||
| token_count = num_tokens_from_string(query) + sum( | |||
| [num_tokens_from_string(t) for t in texts] | |||
| ) | |||
| token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts]) | |||
| res = self.client.rerank( | |||
| model=self.model_name, | |||
| query=query, | |||
| @@ -419,6 +398,8 @@ class CoHereRerank(Base): | |||
| class TogetherAIRerank(Base): | |||
| _FACTORY_NAME = "TogetherAI" | |||
| def __init__(self, key, model_name, base_url): | |||
| pass | |||
| @@ -427,9 +408,9 @@ class TogetherAIRerank(Base): | |||
| class SILICONFLOWRerank(Base): | |||
| def __init__( | |||
| self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank" | |||
| ): | |||
| _FACTORY_NAME = "SILICONFLOW" | |||
| def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"): | |||
| if not base_url: | |||
| base_url = "https://api.siliconflow.cn/v1/rerank" | |||
| self.model_name = model_name | |||
| @@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base): | |||
| "max_chunks_per_doc": 1024, | |||
| "overlap_tokens": 80, | |||
| } | |||
| response = requests.post( | |||
| self.base_url, json=payload, headers=self.headers | |||
| ).json() | |||
| response = requests.post(self.base_url, json=payload, headers=self.headers).json() | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| try: | |||
| for d in response["results"]: | |||
| @@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base): | |||
| class BaiduYiyanRerank(Base): | |||
| _FACTORY_NAME = "BaiduYiyan" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| from qianfan.resources import Reranker | |||
| @@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base): | |||
| class VoyageRerank(Base): | |||
| _FACTORY_NAME = "Voyage AI" | |||
| def __init__(self, key, model_name, base_url=None): | |||
| import voyageai | |||
| @@ -502,9 +485,7 @@ class VoyageRerank(Base): | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| if not texts: | |||
| return rank, 0 | |||
| res = self.client.rerank( | |||
| query=query, documents=texts, model=self.model_name, top_k=len(texts) | |||
| ) | |||
| res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts)) | |||
| try: | |||
| for r in res.results: | |||
| rank[r.index] = r.relevance_score | |||
| @@ -514,22 +495,20 @@ class VoyageRerank(Base): | |||
| class QWenRerank(Base): | |||
| def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs): | |||
| _FACTORY_NAME = "Tongyi-Qianwen" | |||
| def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs): | |||
| import dashscope | |||
| self.api_key = key | |||
| self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name | |||
| def similarity(self, query: str, texts: list): | |||
| import dashscope | |||
| from http import HTTPStatus | |||
| resp = dashscope.TextReRank.call( | |||
| api_key=self.api_key, | |||
| model=self.model_name, | |||
| query=query, | |||
| documents=texts, | |||
| top_n=len(texts), | |||
| return_documents=False | |||
| ) | |||
| import dashscope | |||
| resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False) | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| if resp.status_code == HTTPStatus.OK: | |||
| try: | |||
| @@ -543,6 +522,8 @@ class QWenRerank(Base): | |||
| class HuggingfaceRerank(DefaultRerank): | |||
| _FACTORY_NAME = "HuggingFace" | |||
| @staticmethod | |||
| def post(query: str, texts: list, url="127.0.0.1"): | |||
| exc = None | |||
| @@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank): | |||
| batch_size = 8 | |||
| for i in range(0, len(texts), batch_size): | |||
| try: | |||
| res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"}, | |||
| json={"query": query, "texts": texts[i: i + batch_size], | |||
| "raw_scores": False, "truncate": True}) | |||
| res = requests.post( | |||
| f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True} | |||
| ) | |||
| for o in res.json(): | |||
| scores[o["index"] + i] = o["score"] | |||
| @@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank): | |||
| class GPUStackRerank(Base): | |||
| def __init__( | |||
| self, key, model_name, base_url | |||
| ): | |||
| _FACTORY_NAME = "GPUStack" | |||
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("url cannot be None") | |||
| @@ -600,9 +581,7 @@ class GPUStackRerank(Base): | |||
| } | |||
| try: | |||
| response = requests.post( | |||
| self.base_url, json=payload, headers=self.headers | |||
| ) | |||
| response = requests.post(self.base_url, json=payload, headers=self.headers) | |||
| response.raise_for_status() | |||
| response_json = response.json() | |||
| @@ -623,11 +602,12 @@ class GPUStackRerank(Base): | |||
| ) | |||
| except httpx.HTTPStatusError as e: | |||
| raise ValueError( | |||
| f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") | |||
| raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") | |||
| class NovitaRerank(JinaRerank): | |||
| _FACTORY_NAME = "NovitaAI" | |||
| def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"): | |||
| if not base_url: | |||
| base_url = "https://api.novita.ai/v3/openai/rerank" | |||
| @@ -635,7 +615,9 @@ class NovitaRerank(JinaRerank): | |||
| class GiteeRerank(JinaRerank): | |||
| _FACTORY_NAME = "GiteeAI" | |||
| def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"): | |||
| if not base_url: | |||
| base_url = "https://ai.gitee.com/v1/rerank" | |||
| super().__init__(key, model_name, base_url) | |||
| @@ -13,16 +13,18 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| import requests | |||
| from openai.lib.azure import AzureOpenAI | |||
| import base64 | |||
| import io | |||
| import json | |||
| import os | |||
| import re | |||
| from abc import ABC | |||
| import requests | |||
| from openai import OpenAI | |||
| import json | |||
| from openai.lib.azure import AzureOpenAI | |||
| from rag.utils import num_tokens_from_string | |||
| import base64 | |||
| import re | |||
| class Base(ABC): | |||
| @@ -30,11 +32,7 @@ class Base(ABC): | |||
| pass | |||
| def transcription(self, audio, **kwargs): | |||
| transcription = self.client.audio.transcriptions.create( | |||
| model=self.model_name, | |||
| file=audio, | |||
| response_format="text" | |||
| ) | |||
| transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text") | |||
| return transcription.text.strip(), num_tokens_from_string(transcription.text.strip()) | |||
| def audio2base64(self, audio): | |||
| @@ -46,6 +44,8 @@ class Base(ABC): | |||
| class GPTSeq2txt(Base): | |||
| _FACTORY_NAME = "OpenAI" | |||
| def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.openai.com/v1" | |||
| @@ -54,31 +54,34 @@ class GPTSeq2txt(Base): | |||
| class QWenSeq2txt(Base): | |||
| _FACTORY_NAME = "Tongyi-Qianwen" | |||
| def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs): | |||
| import dashscope | |||
| dashscope.api_key = key | |||
| self.model_name = model_name | |||
| def transcription(self, audio, format): | |||
| from http import HTTPStatus | |||
| from dashscope.audio.asr import Recognition | |||
| recognition = Recognition(model=self.model_name, | |||
| format=format, | |||
| sample_rate=16000, | |||
| callback=None) | |||
| recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None) | |||
| result = recognition.call(audio) | |||
| ans = "" | |||
| if result.status_code == HTTPStatus.OK: | |||
| for sentence in result.get_sentence(): | |||
| ans += sentence.text.decode('utf-8') + '\n' | |||
| ans += sentence.text.decode("utf-8") + "\n" | |||
| return ans, num_tokens_from_string(ans) | |||
| return "**ERROR**: " + result.message, 0 | |||
| class AzureSeq2txt(Base): | |||
| _FACTORY_NAME = "Azure-OpenAI" | |||
| def __init__(self, key, model_name, lang="Chinese", **kwargs): | |||
| self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") | |||
| self.model_name = model_name | |||
| @@ -86,43 +89,33 @@ class AzureSeq2txt(Base): | |||
| class XinferenceSeq2txt(Base): | |||
| _FACTORY_NAME = "Xinference" | |||
| def __init__(self, key, model_name="whisper-small", **kwargs): | |||
| self.base_url = kwargs.get('base_url', None) | |||
| self.base_url = kwargs.get("base_url", None) | |||
| self.model_name = model_name | |||
| self.key = key | |||
| def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7): | |||
| if isinstance(audio, str): | |||
| audio_file = open(audio, 'rb') | |||
| audio_file = open(audio, "rb") | |||
| audio_data = audio_file.read() | |||
| audio_file_name = audio.split("/")[-1] | |||
| else: | |||
| audio_data = audio | |||
| audio_file_name = "audio.wav" | |||
| payload = { | |||
| "model": self.model_name, | |||
| "language": language, | |||
| "prompt": prompt, | |||
| "response_format": response_format, | |||
| "temperature": temperature | |||
| } | |||
| payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature} | |||
| files = { | |||
| "file": (audio_file_name, audio_data, 'audio/wav') | |||
| } | |||
| files = {"file": (audio_file_name, audio_data, "audio/wav")} | |||
| try: | |||
| response = requests.post( | |||
| f"{self.base_url}/v1/audio/transcriptions", | |||
| files=files, | |||
| data=payload | |||
| ) | |||
| response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload) | |||
| response.raise_for_status() | |||
| result = response.json() | |||
| if 'text' in result: | |||
| transcription_text = result['text'].strip() | |||
| if "text" in result: | |||
| transcription_text = result["text"].strip() | |||
| return transcription_text, num_tokens_from_string(transcription_text) | |||
| else: | |||
| return "**ERROR**: Failed to retrieve transcription.", 0 | |||
| @@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base): | |||
| class TencentCloudSeq2txt(Base): | |||
| def __init__( | |||
| self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com" | |||
| ): | |||
| from tencentcloud.common import credential | |||
| _FACTORY_NAME = "Tencent Cloud" | |||
| def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"): | |||
| from tencentcloud.asr.v20190614 import asr_client | |||
| from tencentcloud.common import credential | |||
| key = json.loads(key) | |||
| sid = key.get("tencent_cloud_sid", "") | |||
| @@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base): | |||
| self.model_name = model_name | |||
| def transcription(self, audio, max_retries=60, retry_interval=5): | |||
| import time | |||
| from tencentcloud.asr.v20190614 import models | |||
| from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( | |||
| TencentCloudSDKException, | |||
| ) | |||
| from tencentcloud.asr.v20190614 import models | |||
| import time | |||
| b64 = self.audio2base64(audio) | |||
| try: | |||
| @@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base): | |||
| while retries < max_retries: | |||
| resp = self.client.DescribeTaskStatus(req) | |||
| if resp.Data.StatusStr == "success": | |||
| text = re.sub( | |||
| r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result | |||
| ).strip() | |||
| text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip() | |||
| return text, num_tokens_from_string(text) | |||
| elif resp.Data.StatusStr == "failed": | |||
| return ( | |||
| @@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base): | |||
| class GPUStackSeq2txt(Base): | |||
| _FACTORY_NAME = "GPUStack" | |||
| def __init__(self, key, model_name, base_url): | |||
| if not base_url: | |||
| raise ValueError("url cannot be None") | |||
| @@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base): | |||
| class GiteeSeq2txt(Base): | |||
| _FACTORY_NAME = "GiteeAI" | |||
| def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"): | |||
| if not base_url: | |||
| base_url = "https://ai.gitee.com/v1/" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| @@ -70,10 +70,12 @@ class Base(ABC): | |||
| pass | |||
| def normalize_text(self, text): | |||
| return re.sub(r'(\*\*|##\d+\$\$|#)', '', text) | |||
| return re.sub(r"(\*\*|##\d+\$\$|#)", "", text) | |||
| class FishAudioTTS(Base): | |||
| _FACTORY_NAME = "Fish Audio" | |||
| def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"): | |||
| if not base_url: | |||
| base_url = "https://api.fish.audio/v1/tts" | |||
| @@ -94,13 +96,11 @@ class FishAudioTTS(Base): | |||
| with httpx.Client() as client: | |||
| try: | |||
| with client.stream( | |||
| method="POST", | |||
| url=self.base_url, | |||
| content=ormsgpack.packb( | |||
| request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC | |||
| ), | |||
| headers=self.headers, | |||
| timeout=None, | |||
| method="POST", | |||
| url=self.base_url, | |||
| content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), | |||
| headers=self.headers, | |||
| timeout=None, | |||
| ) as response: | |||
| if response.status_code == HTTPStatus.OK: | |||
| for chunk in response.iter_bytes(): | |||
| @@ -115,6 +115,8 @@ class FishAudioTTS(Base): | |||
| class QwenTTS(Base): | |||
| _FACTORY_NAME = "Tongyi-Qianwen" | |||
| def __init__(self, key, model_name, base_url=""): | |||
| import dashscope | |||
| @@ -122,10 +124,11 @@ class QwenTTS(Base): | |||
| dashscope.api_key = key | |||
| def tts(self, text): | |||
| from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse | |||
| from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult | |||
| from collections import deque | |||
| from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse | |||
| from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer | |||
| class Callback(ResultCallback): | |||
| def __init__(self) -> None: | |||
| self.dque = deque() | |||
| @@ -159,10 +162,7 @@ class QwenTTS(Base): | |||
| text = self.normalize_text(text) | |||
| callback = Callback() | |||
| SpeechSynthesizer.call(model=self.model_name, | |||
| text=text, | |||
| callback=callback, | |||
| format="mp3") | |||
| SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3") | |||
| try: | |||
| for data in callback._run(): | |||
| yield data | |||
| @@ -173,24 +173,19 @@ class QwenTTS(Base): | |||
| class OpenAITTS(Base): | |||
| _FACTORY_NAME = "OpenAI" | |||
| def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.openai.com/v1" | |||
| self.api_key = key | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Authorization": f"Bearer {self.api_key}", | |||
| "Content-Type": "application/json" | |||
| } | |||
| self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} | |||
| def tts(self, text, voice="alloy"): | |||
| text = self.normalize_text(text) | |||
| payload = { | |||
| "model": self.model_name, | |||
| "voice": voice, | |||
| "input": text | |||
| } | |||
| payload = {"model": self.model_name, "voice": voice, "input": text} | |||
| response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True) | |||
| @@ -201,7 +196,8 @@ class OpenAITTS(Base): | |||
| yield chunk | |||
| class SparkTTS: | |||
| class SparkTTS(Base): | |||
| _FACTORY_NAME = "XunFei Spark" | |||
| STATUS_FIRST_FRAME = 0 | |||
| STATUS_CONTINUE_FRAME = 1 | |||
| STATUS_LAST_FRAME = 2 | |||
| @@ -219,29 +215,23 @@ class SparkTTS: | |||
| # 生成url | |||
| def create_url(self): | |||
| url = 'wss://tts-api.xfyun.cn/v2/tts' | |||
| url = "wss://tts-api.xfyun.cn/v2/tts" | |||
| now = datetime.now() | |||
| date = format_date_time(mktime(now.timetuple())) | |||
| signature_origin = "host: " + "ws-api.xfyun.cn" + "\n" | |||
| signature_origin += "date: " + date + "\n" | |||
| signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" | |||
| signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), | |||
| digestmod=hashlib.sha256).digest() | |||
| signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') | |||
| authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( | |||
| self.APIKey, "hmac-sha256", "host date request-line", signature_sha) | |||
| authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') | |||
| v = { | |||
| "authorization": authorization, | |||
| "date": date, | |||
| "host": "ws-api.xfyun.cn" | |||
| } | |||
| url = url + '?' + urlencode(v) | |||
| signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest() | |||
| signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") | |||
| authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha) | |||
| authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") | |||
| v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"} | |||
| url = url + "?" + urlencode(v) | |||
| return url | |||
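
The URL built above carries xfyun's HMAC-SHA256 handshake signature: a three-line "request line" string (host, RFC 1123 date, and the GET line) is signed with the API secret, and the resulting authorization header is base64-wrapped into a query parameter. A condensed, standalone restatement of the signing step (`sign_request_line` is an illustrative helper, not part of any SDK):

```python
import base64
import hashlib
import hmac
from email.utils import formatdate  # RFC 1123 date, same shape as format_date_time(mktime(...))

def sign_request_line(api_secret: str, date: str, host: str = "ws-api.xfyun.cn", path: str = "/v2/tts") -> str:
    # Rebuild the exact string xfyun expects, then HMAC-SHA256 it.
    origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    digest = hmac.new(api_secret.encode("utf-8"), origin.encode("utf-8"), hashlib.sha256).digest()
    return base64.b64encode(digest).decode("utf-8")

print(sign_request_line("my-api-secret", formatdate(usegmt=True)))
```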
| def tts(self, text): | |||
| BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"} | |||
| Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')} | |||
| Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")} | |||
| CommonArgs = {"app_id": self.APPID} | |||
| audio_queue = self.audio_queue | |||
| model_name = self.model_name | |||
| @@ -273,9 +263,7 @@ class SparkTTS: | |||
| def on_open(self, ws): | |||
| def run(*args): | |||
| d = {"common": CommonArgs, | |||
| "business": BusinessArgs, | |||
| "data": Data} | |||
| d = {"common": CommonArgs, "business": BusinessArgs, "data": Data} | |||
| ws.send(json.dumps(d)) | |||
| thread.start_new_thread(run, ()) | |||
| @@ -283,44 +271,32 @@ class SparkTTS: | |||
| wsUrl = self.create_url() | |||
| websocket.enableTrace(False) | |||
| a = Callback() | |||
| ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, | |||
| on_message=a.on_message) | |||
| ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message) | |||
| status_code = 0 | |||
| ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) | |||
| while True: | |||
| audio_chunk = self.audio_queue.get() | |||
| if audio_chunk is None: | |||
| if status_code == 0: | |||
| raise Exception( | |||
| f"Failed to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.") | |||
| raise Exception(f"Failed to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.") | |||
| else: | |||
| break | |||
| status_code = 1 | |||
| yield audio_chunk | |||
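
Since `tts` is a generator of raw MP3 bytes, callers can stream it straight to a file or socket. An illustrative usage, assuming a configured `SparkTTS` instance named `tts_model` (credentials and model name are placeholders):

```python
# Hypothetical wiring; tts_model is assumed to be an initialized instance.
with open("out.mp3", "wb") as f:
    for chunk in tts_model.tts("Hello, world"):
        f.write(chunk)
```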
| class XinferenceTTS: | |||
| class XinferenceTTS(Base): | |||
| _FACTORY_NAME = "Xinference" | |||
| def __init__(self, key, model_name, **kwargs): | |||
| self.base_url = kwargs.get("base_url", None) | |||
| self.model_name = model_name | |||
| self.headers = { | |||
| "accept": "application/json", | |||
| "Content-Type": "application/json" | |||
| } | |||
| self.headers = {"accept": "application/json", "Content-Type": "application/json"} | |||
| def tts(self, text, voice="中文女", stream=True): | |||
| payload = { | |||
| "model": self.model_name, | |||
| "input": text, | |||
| "voice": voice | |||
| } | |||
| payload = {"model": self.model_name, "input": text, "voice": voice} | |||
| response = requests.post( | |||
| f"{self.base_url}/v1/audio/speech", | |||
| headers=self.headers, | |||
| json=payload, | |||
| stream=stream | |||
| ) | |||
| response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream) | |||
| if response.status_code != 200: | |||
| raise Exception(f"**Error**: {response.status_code}, {response.text}") | |||
| @@ -332,22 +308,16 @@ class XinferenceTTS: | |||
| class OllamaTTS(Base): | |||
| def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.ollama.ai/v1" | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Content-Type": "application/json" | |||
| } | |||
| self.headers = {"Content-Type": "application/json"} | |||
| if key and key != "x": | |||
| self.headers["Authorization"] = f"Bear {key}" | |||
| def tts(self, text, voice="standard-voice"): | |||
| payload = { | |||
| "model": self.model_name, | |||
| "voice": voice, | |||
| "input": text | |||
| } | |||
| payload = {"model": self.model_name, "voice": voice, "input": text} | |||
| response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True) | |||
| @@ -359,30 +329,19 @@ class OllamaTTS(Base): | |||
| yield chunk | |||
| class GPUStackTTS: | |||
| class GPUStackTTS(Base): | |||
| _FACTORY_NAME = "GPUStack" | |||
| def __init__(self, key, model_name, **kwargs): | |||
| self.base_url = kwargs.get("base_url", None) | |||
| self.api_key = key | |||
| self.model_name = model_name | |||
| self.headers = { | |||
| "accept": "application/json", | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {self.api_key}" | |||
| } | |||
| self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} | |||
| def tts(self, text, voice="Chinese Female", stream=True): | |||
| payload = { | |||
| "model": self.model_name, | |||
| "input": text, | |||
| "voice": voice | |||
| } | |||
| payload = {"model": self.model_name, "input": text, "voice": voice} | |||
| response = requests.post( | |||
| f"{self.base_url}/v1/audio/speech", | |||
| headers=self.headers, | |||
| json=payload, | |||
| stream=stream | |||
| ) | |||
| response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream) | |||
| if response.status_code != 200: | |||
| raise Exception(f"**Error**: {response.status_code}, {response.text}") | |||
| @@ -393,16 +352,15 @@ class GPUStackTTS: | |||
| class SILICONFLOWTTS(Base): | |||
| _FACTORY_NAME = "SILICONFLOW" | |||
| def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.siliconflow.cn/v1" | |||
| self.api_key = key | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| self.headers = { | |||
| "Authorization": f"Bearer {self.api_key}", | |||
| "Content-Type": "application/json" | |||
| } | |||
| self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} | |||
| def tts(self, text, voice="anna"): | |||
| text = self.normalize_text(text) | |||
| @@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base): | |||
| "sample_rate": 123, | |||
| "stream": True, | |||
| "speed": 1, | |||
| "gain": 0 | |||
| "gain": 0, | |||
| } | |||
| response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload) | |||