
Refa: automatic LLMs registration (#8651)

### What problem does this PR solve?

Support automatic LLM registration.

### Type of change

- [x] Refactoring
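
The mechanism is visible in full in the `rag/llm/__init__.py` diff below: each provider class declares a `_FACTORY_NAME` class attribute (a string, or a list of strings for aliases such as VLLM / OpenAI-API-Compatible), and the package `__init__` scans each model module and registers every `Base` subclass into the matching registry dict. A condensed, standalone sketch of that loop (module set and registry dicts simplified here for illustration):

```python
import importlib
import inspect

# Simplified: the real code maps all six model modules to their registries.
MODULE_MAPPING = {"chat_model": {}, "embedding_model": {}}

for module_name, registry in MODULE_MAPPING.items():
    module = importlib.import_module(f"rag.llm.{module_name}")
    base_class = getattr(module, "Base", None)  # each module defines a Base ABC
    if base_class is None:
        continue
    for _, obj in inspect.getmembers(module, inspect.isclass):
        # Register every Base subclass that declares a factory name.
        if issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
            names = obj._FACTORY_NAME if isinstance(obj._FACTORY_NAME, list) else [obj._FACTORY_NAME]
            for factory_name in names:
                registry[factory_name] = obj
```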
tags/v0.20.0
Yongteng Lei, 4 months ago
Commit f8a6987f1e (no account is linked to the committer's email)
7 changed files, with 617 additions and 874 deletions:
1. rag/llm/__init__.py (+41, −277)
2. rag/llm/chat_model.py (+92, −12)
3. rag/llm/cv_model.py (+166, −163)
4. rag/llm/embedding_model.py (+152, −193)
5. rag/llm/rerank_model.py (+74, −92)
6. rag/llm/sequence2txt_model.py (+42, −45)
7. rag/llm/tts_model.py (+50, −92)

rag/llm/__init__.py (+41, −277)

@@ -15,289 +15,53 @@
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#
from .embedding_model import (
OllamaEmbed,
LocalAIEmbed,
OpenAIEmbed,
AzureEmbed,
XinferenceEmbed,
QWenEmbed,
ZhipuEmbed,
FastEmbed,
YoudaoEmbed,
BaiChuanEmbed,
JinaEmbed,
DefaultEmbedding,
MistralEmbed,
BedrockEmbed,
GeminiEmbed,
NvidiaEmbed,
LmStudioEmbed,
OpenAI_APIEmbed,
CoHereEmbed,
TogetherAIEmbed,
PerfXCloudEmbed,
UpstageEmbed,
SILICONFLOWEmbed,
ReplicateEmbed,
BaiduYiyanEmbed,
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
GPUStackEmbed,
NovitaEmbed,
GiteeEmbed
)
from .chat_model import (
GptTurbo,
AzureChat,
ZhipuChat,
QWenChat,
OllamaChat,
LocalAIChat,
XinferenceChat,
MoonshotChat,
DeepSeekChat,
VolcEngineChat,
BaiChuanChat,
MiniMaxChat,
MistralChat,
GeminiChat,
BedrockChat,
GroqChat,
OpenRouterChat,
StepFunChat,
NvidiaChat,
LmStudioChat,
OpenAI_APIChat,
CoHereChat,
LeptonAIChat,
TogetherAIChat,
PerfXCloudChat,
UpstageChat,
NovitaAIChat,
SILICONFLOWChat,
PPIOChat,
YiChat,
ReplicateChat,
HunyuanChat,
SparkChat,
BaiduYiyanChat,
AnthropicChat,
GoogleChat,
HuggingFaceChat,
GPUStackChat,
ModelScopeChat,
GiteeChat
)

from .cv_model import (
GptV4,
AzureGptV4,
OllamaCV,
XinferenceCV,
QWenCV,
Zhipu4V,
LocalCV,
GeminiCV,
OpenRouterCV,
LocalAICV,
NvidiaCV,
LmStudioCV,
StepFunCV,
OpenAI_APICV,
TogetherAICV,
YiCV,
HunyuanCV,
AnthropicCV,
SILICONFLOWCV,
GPUStackCV,
GoogleCV,
)
import importlib
import inspect

from .rerank_model import (
LocalAIRerank,
DefaultRerank,
JinaRerank,
YoudaoRerank,
XInferenceRerank,
NvidiaRerank,
LmStudioRerank,
OpenAI_APIRerank,
CoHereRerank,
TogetherAIRerank,
SILICONFLOWRerank,
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
GPUStackRerank,
HuggingfaceRerank,
NovitaRerank,
GiteeRerank
)
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})

from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
GPUStackSeq2txt,
GiteeSeq2txt
)

from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
GPUStackTTS,
SILICONFLOWTTS,
)

EmbeddingModel = {
"Ollama": OllamaEmbed,
"LocalAI": LocalAIEmbed,
"OpenAI": OpenAIEmbed,
"Azure-OpenAI": AzureEmbed,
"Xinference": XinferenceEmbed,
"Tongyi-Qianwen": QWenEmbed,
"ZHIPU-AI": ZhipuEmbed,
"FastEmbed": FastEmbed,
"Youdao": YoudaoEmbed,
"BaiChuan": BaiChuanEmbed,
"Jina": JinaEmbed,
"BAAI": DefaultEmbedding,
"Mistral": MistralEmbed,
"Bedrock": BedrockEmbed,
"Gemini": GeminiEmbed,
"NVIDIA": NvidiaEmbed,
"LM-Studio": LmStudioEmbed,
"OpenAI-API-Compatible": OpenAI_APIEmbed,
"VLLM": OpenAI_APIEmbed,
"Cohere": CoHereEmbed,
"TogetherAI": TogetherAIEmbed,
"PerfXCloud": PerfXCloudEmbed,
"Upstage": UpstageEmbed,
"SILICONFLOW": SILICONFLOWEmbed,
"Replicate": ReplicateEmbed,
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed,
"VolcEngine": VolcEngineEmbed,
"GPUStack": GPUStackEmbed,
"NovitaAI": NovitaEmbed,
"GiteeAI": GiteeEmbed
MODULE_MAPPING = {
"chat_model": ChatModel,
"cv_model": CvModel,
"embedding_model": EmbeddingModel,
"rerank_model": RerankModel,
"sequence2txt_model": Seq2txtModel,
"tts_model": TTSModel,
}

CvModel = {
"OpenAI": GptV4,
"Azure-OpenAI": AzureGptV4,
"Ollama": OllamaCV,
"Xinference": XinferenceCV,
"Tongyi-Qianwen": QWenCV,
"ZHIPU-AI": Zhipu4V,
"Moonshot": LocalCV,
"Gemini": GeminiCV,
"OpenRouter": OpenRouterCV,
"LocalAI": LocalAICV,
"NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV,
"StepFun": StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV,
"VLLM": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV,
"Anthropic": AnthropicCV,
"SILICONFLOW": SILICONFLOWCV,
"GPUStack": GPUStackCV,
"Google Cloud": GoogleCV
}
package_name = __name__

ChatModel = {
"OpenAI": GptTurbo,
"Azure-OpenAI": AzureChat,
"ZHIPU-AI": ZhipuChat,
"Tongyi-Qianwen": QWenChat,
"Ollama": OllamaChat,
"LocalAI": LocalAIChat,
"Xinference": XinferenceChat,
"Moonshot": MoonshotChat,
"DeepSeek": DeepSeekChat,
"VolcEngine": VolcEngineChat,
"BaiChuan": BaiChuanChat,
"MiniMax": MiniMaxChat,
"Mistral": MistralChat,
"Gemini": GeminiChat,
"Bedrock": BedrockChat,
"Groq": GroqChat,
"OpenRouter": OpenRouterChat,
"StepFun": StepFunChat,
"NVIDIA": NvidiaChat,
"LM-Studio": LmStudioChat,
"OpenAI-API-Compatible": OpenAI_APIChat,
"VLLM": OpenAI_APIChat,
"Cohere": CoHereChat,
"LeptonAI": LeptonAIChat,
"TogetherAI": TogetherAIChat,
"PerfXCloud": PerfXCloudChat,
"Upstage": UpstageChat,
"NovitaAI": NovitaAIChat,
"SILICONFLOW": SILICONFLOWChat,
"PPIO": PPIOChat,
"01.AI": YiChat,
"Replicate": ReplicateChat,
"Tencent Hunyuan": HunyuanChat,
"XunFei Spark": SparkChat,
"BaiduYiyan": BaiduYiyanChat,
"Anthropic": AnthropicChat,
"Google Cloud": GoogleChat,
"HuggingFace": HuggingFaceChat,
"GPUStack": GPUStackChat,
"ModelScope":ModelScopeChat,
"GiteeAI": GiteeChat
}
for module_name, mapping_dict in MODULE_MAPPING.items():
full_module_name = f"{package_name}.{module_name}"
module = importlib.import_module(full_module_name)

RerankModel = {
"LocalAI": LocalAIRerank,
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
"Xinference": XInferenceRerank,
"NVIDIA": NvidiaRerank,
"LM-Studio": LmStudioRerank,
"OpenAI-API-Compatible": OpenAI_APIRerank,
"VLLM": CoHereRerank,
"Cohere": CoHereRerank,
"TogetherAI": TogetherAIRerank,
"SILICONFLOW": SILICONFLOWRerank,
"BaiduYiyan": BaiduYiyanRerank,
"Voyage AI": VoyageRerank,
"Tongyi-Qianwen": QWenRerank,
"GPUStack": GPUStackRerank,
"HuggingFace": HuggingfaceRerank,
"NovitaAI": NovitaRerank,
"GiteeAI": GiteeRerank
}
base_class = None
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and name == "Base":
base_class = obj
break
if base_class is None:
continue

Seq2txtModel = {
"OpenAI": GPTSeq2txt,
"Tongyi-Qianwen": QWenSeq2txt,
"Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt,
"GPUStack": GPUStackSeq2txt,
"GiteeAI": GiteeSeq2txt
}
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
if isinstance(obj._FACTORY_NAME, list):
for factory_name in obj._FACTORY_NAME:
mapping_dict[factory_name] = obj
else:
mapping_dict[obj._FACTORY_NAME] = obj

TTSModel = {
"Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS,
"OpenAI": OpenAITTS,
"XunFei Spark": SparkTTS,
"Xinference": XinferenceTTS,
"GPUStack": GPUStackTTS,
"SILICONFLOW": SILICONFLOWTTS,
}
__all__ = [
"ChatModel",
"CvModel",
"EmbeddingModel",
"RerankModel",
"Seq2txtModel",
"TTSModel",
]
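
Downstream lookups are unchanged by this refactor, since the registry dicts keep their old names and keys. A hypothetical call site, assuming the registries above are populated at import time (which the loop guarantees):

```python
from rag.llm import ChatModel, EmbeddingModel

# Keys are the _FACTORY_NAME values declared on each provider class.
chat_cls = ChatModel["OpenAI"]        # resolves to GptTurbo
embed_cls = EmbeddingModel["BAAI"]    # resolves to DefaultEmbedding

# Hypothetical instantiation; real call sites pass provider-specific arguments.
chat = chat_cls(key="<api-key>", model_name="gpt-3.5-turbo")
```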

rag/llm/chat_model.py (+92, −12)

@@ -142,11 +142,7 @@ class Base(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({
"name": name,
"args": args,
"result": res
}, ensure_ascii=False, indent=2) + "</tool_call>"
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

def _append_history(self, hist, tool_call, tool_res):
hist.append(
@@ -191,10 +187,10 @@ class Base(ABC):
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries+1):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds*2):
for _ in range(self.max_rounds * 2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message]):
@@ -269,7 +265,7 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds*2):
for _ in range(self.max_rounds * 2):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
final_tool_calls = {}
@@ -430,6 +426,8 @@ class Base(ABC):


class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -437,6 +435,8 @@ class GptTurbo(Base):


class MoonshotChat(Base):
_FACTORY_NAME = "Moonshot"

def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
if not base_url:
base_url = "https://api.moonshot.cn/v1"
@@ -444,6 +444,8 @@ class MoonshotChat(Base):


class XinferenceChat(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -452,6 +454,8 @@ class XinferenceChat(Base):


class HuggingFaceChat(Base):
_FACTORY_NAME = "HuggingFace"

def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):


class ModelScopeChat(Base):
_FACTORY_NAME = "ModelScope"

def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -468,6 +474,8 @@ class ModelScopeChat(Base):


class DeepSeekChat(Base):
_FACTORY_NAME = "DeepSeek"

def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deepseek.com/v1"
@@ -475,6 +483,8 @@ class DeepSeekChat(Base):


class AzureChat(Base):
_FACTORY_NAME = "Azure-OpenAI"

def __init__(self, key, model_name, base_url, **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
@@ -484,6 +494,8 @@ class AzureChat(Base):


class BaiChuanChat(Base):
_FACTORY_NAME = "BaiChuan"

def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
@@ -557,6 +569,8 @@ class BaiChuanChat(Base):


class QWenChat(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
if not base_url:
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -565,6 +579,8 @@ class QWenChat(Base):


class ZhipuChat(Base):
_FACTORY_NAME = "ZHIPU-AI"

def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -630,6 +646,8 @@ class ZhipuChat(Base):


class OllamaChat(Base):
_FACTORY_NAME = "Ollama"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -694,6 +712,8 @@ class OllamaChat(Base):


class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -752,6 +772,8 @@ class LocalLLM(Base):


class VolcEngineChat(Base):
_FACTORY_NAME = "VolcEngine"

def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
@@ -765,6 +787,8 @@ class VolcEngineChat(Base):


class MiniMaxChat(Base):
_FACTORY_NAME = "MiniMax"

def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -843,6 +867,8 @@ class MiniMaxChat(Base):


class MistralChat(Base):
_FACTORY_NAME = "Mistral"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -896,6 +922,8 @@ class MistralChat(Base):


class BedrockChat(Base):
_FACTORY_NAME = "Bedrock"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -978,6 +1006,8 @@ class BedrockChat(Base):


class GeminiChat(Base):
_FACTORY_NAME = "Gemini"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -997,6 +1027,7 @@ class GeminiChat(Base):

def _chat(self, history, gen_conf):
from google.generativeai.types import content_types

system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = []
for item in history:
@@ -1019,6 +1050,7 @@ class GeminiChat(Base):

def chat_streamly(self, system, history, gen_conf):
from google.generativeai.types import content_types

gen_conf = self._clean_conf(gen_conf)
if system:
self.model._system_instruction = content_types.to_content(system)
@@ -1042,6 +1074,8 @@ class GeminiChat(Base):


class GroqChat(Base):
_FACTORY_NAME = "Groq"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1086,6 +1120,8 @@ class GroqChat(Base):

## openrouter
class OpenRouterChat(Base):
_FACTORY_NAME = "OpenRouter"

def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):


class StepFunChat(Base):
_FACTORY_NAME = "StepFun"

def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
if not base_url:
base_url = "https://api.stepfun.com/v1"
@@ -1100,6 +1138,8 @@ class StepFunChat(Base):


class NvidiaChat(Base):
_FACTORY_NAME = "NVIDIA"

def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1"
@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):


class LmStudioChat(Base):
_FACTORY_NAME = "LM-Studio"

def __init__(self, key, model_name, base_url, **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):


class OpenAI_APIChat(Base):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):


class PPIOChat(Base):
_FACTORY_NAME = "PPIO"

def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
if not base_url:
base_url = "https://api.ppinfra.com/v3/openai"
@@ -1132,6 +1178,8 @@ class PPIOChat(Base):


class CoHereChat(Base):
_FACTORY_NAME = "Cohere"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1207,6 +1255,8 @@ class CoHereChat(Base):


class LeptonAIChat(Base):
_FACTORY_NAME = "LeptonAI"

def __init__(self, key, model_name, base_url=None, **kwargs):
if not base_url:
base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):


class TogetherAIChat(Base):
_FACTORY_NAME = "TogetherAI"

def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):


class PerfXCloudChat(Base):
_FACTORY_NAME = "PerfXCloud"

def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):


class UpstageChat(Base):
_FACTORY_NAME = "Upstage"

def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
@@ -1235,6 +1291,8 @@ class UpstageChat(Base):


class NovitaAIChat(Base):
_FACTORY_NAME = "NovitaAI"

def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
if not base_url:
base_url = "https://api.novita.ai/v3/openai"
@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):


class SILICONFLOWChat(Base):
_FACTORY_NAME = "SILICONFLOW"

def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):


class YiChat(Base):
_FACTORY_NAME = "01.AI"

def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
@@ -1256,6 +1318,8 @@ class YiChat(Base):


class GiteeChat(Base):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
if not base_url:
base_url = "https://ai.gitee.com/v1/"
@@ -1263,6 +1327,8 @@ class GiteeChat(Base):


class ReplicateChat(Base):
_FACTORY_NAME = "Replicate"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):


class HunyuanChat(Base):
_FACTORY_NAME = "Tencent Hunyuan"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):


class SparkChat(Base):
_FACTORY_NAME = "XunFei Spark"

def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1398,6 +1468,8 @@ class SparkChat(Base):


class BaiduYiyanChat(Base):
_FACTORY_NAME = "BaiduYiyan"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):


class AnthropicChat(Base):
_FACTORY_NAME = "Anthropic"

def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
if not base_url:
base_url = "https://api.anthropic.com/v1/"
@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):


class GoogleChat(Base):
_FACTORY_NAME = "Google Cloud"

def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = [{
"text": item.pop("content"),
}]
item["parts"] = [
{
"text": item.pop("content"),
}
]

response = self.client.generate_content(hist, generation_config=gen_conf)
ans = response.text
@@ -1587,8 +1665,10 @@ class GoogleChat(Base):


class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"

def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
super().__init__(key, model_name, base_url, **kwargs)
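
The per-class edits above all follow the same pattern: add a `_FACTORY_NAME` and the registration loop in `rag/llm/__init__.py` does the rest. As a hypothetical sketch (provider name, class, and URL invented for illustration, not part of this PR), adding a new OpenAI-compatible chat provider would need only:

```python
class ExampleChat(Base):
    # This key is what appears in ChatModel["ExampleAI"] after registration.
    _FACTORY_NAME = "ExampleAI"

    def __init__(self, key, model_name, base_url="https://api.example.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.example.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)
```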

rag/llm/cv_model.py (+166, −163)

@@ -57,7 +57,7 @@ class Base(ABC):
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)
top_p=gen_conf.get("top_p", 0.7),
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
@@ -79,7 +79,7 @@ class Base(ABC):
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
stream=True
stream=True,
)
for resp in response:
if not resp.choices[0].delta.content:
@@ -87,8 +87,7 @@ class Base(ABC):
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
@@ -117,13 +116,12 @@ class Base(ABC):
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
{
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -136,9 +134,7 @@ class Base(ABC):
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
{
"type": "text",
@@ -156,14 +152,13 @@ class Base(ABC):
"url": f"data:image/jpeg;base64,{b64}",
},
},
{
"type": "text",
"text": text
},
{"type": "text", "text": text},
]


class GptV4(Base):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -181,7 +176,7 @@ class GptV4(Base):

res = self.client.chat.completions.create(
model=self.model_name,
messages=prompt
messages=prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens

@@ -197,9 +192,11 @@ class GptV4(Base):


class AzureGptV4(Base):
_FACTORY_NAME = "Azure-OpenAI"

def __init__(self, key, model_name, lang="Chinese", **kwargs):
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
@@ -212,10 +209,7 @@ class AzureGptV4(Base):
if "text" in c:
c["type"] = "text"

res = self.client.chat.completions.create(
model=self.model_name,
messages=prompt
)
res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
return res.choices[0].message.content.strip(), res.usage.total_tokens

def describe_with_prompt(self, image, prompt=None):
@@ -230,8 +224,11 @@ class AzureGptV4(Base):


class QWenCV(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
import dashscope

dashscope.api_key = key
self.model_name = model_name
self.lang = lang
@@ -247,12 +244,11 @@ class QWenCV(Base):
{
"role": "user",
"content": [
{"image": f"file://{path}"},
{
"image": f"file://{path}"
},
{
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -270,11 +266,9 @@ class QWenCV(Base):
{
"role": "user",
"content": [
{"image": f"file://{path}"},
{
"image": f"file://{path}"
},
{
"text": prompt if prompt else vision_llm_describe_prompt(),
"text": prompt if prompt else vision_llm_describe_prompt(),
},
],
}
@@ -290,9 +284,10 @@ class QWenCV(Base):
from http import HTTPStatus

from dashscope import MultiModalConversation

response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
return response.message, 0

def describe_with_prompt(self, image, prompt=None):
@@ -303,33 +298,36 @@ class QWenCV(Base):
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
return response.message, 0

def chat(self, system, history, gen_conf, image=""):
from http import HTTPStatus

from dashscope import MultiModalConversation

if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

for his in history:
if his["role"] == "user":
his["content"] = self.chat_prompt(his["content"], image)
response = MultiModalConversation.call(model=self.model_name, messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7))
response = MultiModalConversation.call(
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
)

ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans = response.output.choices[0]['message']['content']
ans = response.output.choices[0]["message"]["content"]
if isinstance(ans, list):
ans = ans[0]["text"] if ans else ""
tk_count += response.usage.total_tokens
if response.output.choices[0].get("finish_reason", "") == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, tk_count

return "**ERROR**: " + response.message, tk_count
@@ -338,6 +336,7 @@ class QWenCV(Base):
from http import HTTPStatus

from dashscope import MultiModalConversation

if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -348,24 +347,25 @@ class QWenCV(Base):
ans = ""
tk_count = 0
try:
response = MultiModalConversation.call(model=self.model_name, messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
stream=True)
response = MultiModalConversation.call(
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
stream=True,
)
for resp in response:
if resp.status_code == HTTPStatus.OK:
cnt = resp.output.choices[0]['message']['content']
cnt = resp.output.choices[0]["message"]["content"]
if isinstance(cnt, list):
cnt = cnt[0]["text"] if ans else ""
ans += cnt
tk_count = resp.usage.total_tokens
if resp.output.choices[0].get("finish_reason", "") == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)

@@ -373,6 +373,8 @@ class QWenCV(Base):


class Zhipu4V(Base):
_FACTORY_NAME = "ZHIPU-AI"

def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@@ -394,10 +396,7 @@ class Zhipu4V(Base):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)

res = self.client.chat.completions.create(
model=self.model_name,
messages=vision_prompt
)
res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
return res.choices[0].message.content.strip(), res.usage.total_tokens

def chat(self, system, history, gen_conf, image=""):
@@ -412,7 +411,7 @@ class Zhipu4V(Base):
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)
top_p=gen_conf.get("top_p", 0.7),
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
@@ -434,7 +433,7 @@ class Zhipu4V(Base):
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
stream=True
stream=True,
)
for resp in response:
if not resp.choices[0].delta.content:
@@ -442,8 +441,7 @@ class Zhipu4V(Base):
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
@@ -455,6 +453,8 @@ class Zhipu4V(Base):


class OllamaCV(Base):
_FACTORY_NAME = "Ollama"

def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
@@ -466,7 +466,7 @@ class OllamaCV(Base):
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"][1]["text"],
images=[image]
images=[image],
)
ans = response["response"].strip()
return ans, 128
@@ -507,7 +507,7 @@ class OllamaCV(Base):
model=self.model_name,
messages=history,
options=options,
keep_alive=-1
keep_alive=-1,
)

ans = response["message"]["content"].strip()
@@ -538,7 +538,7 @@ class OllamaCV(Base):
messages=history,
stream=True,
options=options,
keep_alive=-1
keep_alive=-1,
)
for resp in response:
if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):


class LocalAICV(GptV4):
_FACTORY_NAME = "LocalAI"

def __init__(self, key, model_name, base_url, lang="Chinese"):
if not base_url:
raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):


class XinferenceCV(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name="", lang="Chinese", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
def describe(self, image):
b64 = self.image2base64(image)

res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64)
)
res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
return res.choices[0].message.content.strip(), res.usage.total_tokens

def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):


class GeminiCV(Base):
_FACTORY_NAME = "Gemini"

def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import GenerativeModel, client

client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):

def describe(self, image):
from PIL.Image import open
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."

prompt = (
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
)
b64 = self.image2base64(image)
img = open(BytesIO(base64.b64decode(b64)))
input = [prompt, img]
res = self.model.generate_content(
input
)
res = self.model.generate_content(input)
return res.text, res.usage_metadata.total_token_count

def describe_with_prompt(self, image, prompt=None):
from PIL.Image import open

b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):

def chat(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig

if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
his.pop("content")
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)))
response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

ans = response.text
return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):

def chat_streamly(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig

if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -661,9 +668,11 @@ class GeminiCV(Base):
his.pop("content")
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True)
response = self.model.generate_content(
history,
generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
stream=True,
)

for resp in response:
if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):


class OpenRouterCV(GptV4):
_FACTORY_NAME = "OpenRouter"

def __init__(
self,
key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):


class LocalCV(Base):
_FACTORY_NAME = "Moonshot"

def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass

@@ -700,6 +713,8 @@ class LocalCV(Base):


class NvidiaCV(Base):
_FACTORY_NAME = "NVIDIA"

def __init__(
self,
key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={
"messages": self.prompt(b64)
},
json={"messages": self.prompt(b64)},
)
response = response.json()
return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
return [
{
"role": "user",
"content": (
prompt if prompt else vision_llm_describe_prompt()
)
+ f' <img src="data:image/jpeg;base64,{b64}"/>',
"content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
}
]

@@ -791,6 +801,8 @@ class NvidiaCV(Base):


class StepFunCV(GptV4):
_FACTORY_NAME = "StepFun"

def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url:
base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):


class LmStudioCV(GptV4):
_FACTORY_NAME = "LM-Studio"

def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):


class OpenAI_APICV(GptV4):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):


class TogetherAICV(GptV4):
_FACTORY_NAME = "TogetherAI"

def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):


class YiCV(GptV4):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
_FACTORY_NAME = "01.AI"

def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://api.lingyiwanwu.com/v1",
):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
super().__init__(key, model_name, lang, base_url)


class SILICONFLOWCV(GptV4):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
_FACTORY_NAME = "SILICONFLOW"

def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://api.siliconflow.cn/v1",
):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, lang, base_url)


class HunyuanCV(Base):
_FACTORY_NAME = "Tencent Hunyuan"

def __init__(self, key, model_name, lang="Chinese", base_url=None):
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
"Contents": [
{
"Type": "image_url",
"ImageUrl": {
"Url": f"data:image/jpeg;base64,{b64}"
},
"ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
},
{
"Type": "text",
"Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
"Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -910,6 +945,8 @@ class HunyuanCV(Base):


class AnthropicCV(Base):
_FACTORY_NAME = "Anthropic"

def __init__(self, key, model_name, base_url=None):
import anthropic

@@ -933,38 +970,29 @@ class AnthropicCV(Base):
"data": b64,
},
},
{
"type": "text",
"text": prompt
}
{"type": "text", "text": prompt},
],
}
]

def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64,
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
)

response = self.client.messages.create(
model=self.model_name,
max_tokens=self.max_tokens,
messages=prompt
prompt = self.prompt(
b64,
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
)
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]

response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

response = self.client.messages.create(
model=self.model_name,
max_tokens=self.max_tokens,
messages=prompt
)
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

def chat(self, system, history, gen_conf):
if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
).to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
**gen_conf,
)
for res in response:
if res.type == 'content_block_delta':
if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
if ans.find("<think>") < 0:
ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):

yield total_tokens


class GPUStackCV(GptV4):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -1041,11 +1068,13 @@ class GPUStackCV(GptV4):


class GoogleCV(Base):
_FACTORY_NAME = "Google Cloud"

def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
import base64

from google.oauth2 import service_account
key = json.loads(key)
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
project_id = key.get("google_project_id", "")
@@ -1079,9 +1108,12 @@ class GoogleCV(Base):
self.client = glm.GenerativeModel(model_name=self.model_name)

def describe(self, image):
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
prompt = (
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
)

if "claude" in self.model_name:
b64 = self.image2base64(image)
vision_prompt = [
@@ -1096,28 +1128,22 @@ class GoogleCV(Base):
"data": b64,
},
},
{
"type": "text",
"text": prompt
}
{"type": "text", "text": prompt},
],
}
]
response = self.client.messages.create(
model=self.model_name,
max_tokens=8192,
messages=vision_prompt
messages=vision_prompt,
)
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
else:
import vertexai.generative_models as glm
b64 = self.image2base64(image)
# Create proper image part for Gemini
image_part = glm.Part.from_data(
data=base64.b64decode(b64),
mime_type="image/jpeg"
)
image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
input = [prompt, image_part]
res = self.client.generate_content(input)
return res.text, res.usage_metadata.total_token_count
@@ -1137,29 +1163,19 @@ class GoogleCV(Base):
"data": b64,
},
},
{
"type": "text",
"text": prompt if prompt else vision_llm_describe_prompt()
}
{"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
],
}
]
response = self.client.messages.create(
model=self.model_name,
max_tokens=8192,
messages=vision_prompt
)
response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
else:
import vertexai.generative_models as glm

b64 = self.image2base64(image)
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
# Create proper image part for Gemini
image_part = glm.Part.from_data(
data=base64.b64decode(b64),
mime_type="image/jpeg"
)
image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
input = [vision_prompt, image_part]
res = self.client.generate_content(input)
return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
"data": image,
},
},
{
"type": "text",
"text": his["content"]
}
{"type": "text", "text": his["content"]},
]

response = self.client.messages.create(
model=self.model_name,
max_tokens=8192,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)
)
response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
else:
import vertexai.generative_models as glm
from transformers import GenerationConfig

if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
try:
@@ -1210,20 +1218,15 @@ class GoogleCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
# Create proper image part for Gemini
img_bytes = base64.b64decode(image)
image_part = glm.Part.from_data(
data=img_bytes,
mime_type="image/jpeg"
)
image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
history[-1]["parts"].append(image_part)

response = self.client.generate_content(history, generation_config=GenerationConfig(
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)))
response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: " + str(e), 0

rag/llm/embedding_model.py (+152, −193)

@@ -13,28 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin

import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
import numpy as np
import asyncio
from zhipuai import ZhipuAI

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
import json


class Base(ABC):
@@ -60,7 +59,8 @@ class Base(ABC):


class DefaultEmbedding(Base):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
_FACTORY_NAME = "BAAI"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
_model = None
_model_name = ""
_model_lock = threading.Lock()
@@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
"""
if not settings.LIGHTEN:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
from FlagEmbedding import FlagModel

if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
DefaultEmbedding._model = FlagModel(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available(),
)
DefaultEmbedding._model_name = model_name
except Exception:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
DefaultEmbedding._model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
model_dir = snapshot_download(
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
self._model = DefaultEmbedding._model
self._model_name = DefaultEmbedding._model_name

@@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
token_count += num_tokens_from_string(t)
ress = []
for i in range(0, len(texts), batch_size):
ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
return np.array(ress), token_count

def encode_queries(self, text: str):
@@ -114,8 +115,9 @@ class DefaultEmbedding(Base):


class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002",
base_url="https://api.openai.com/v1"):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size],
model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens

def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)],
model=self.model_name)
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)


class LocalAIEmbed(Base):
_FACTORY_NAME = "LocalAI"

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local embedding model url cannot be None")
@@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
except Exception as _e:
@@ -169,41 +171,42 @@ class LocalAIEmbed(Base):


class AzureEmbed(OpenAIEmbed):
_FACTORY_NAME = "Azure-OpenAI"

def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')

api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name


class BaiChuanEmbed(OpenAIEmbed):
def __init__(self, key,
model_name='Baichuan-Text-Embedding',
base_url='https://api.baichuan-ai.com/v1'):
_FACTORY_NAME = "BaiChuan"
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)


class QWenEmbed(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name="text_embedding_v2", **kwargs):
self.key = key
self.model_name = model_name

def encode(self, texts: list):
import dashscope

batch_size = 4
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=texts[i:i + batch_size],
api_key=self.key,
text_type="document"
)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
@@ -216,20 +219,16 @@ class QWenEmbed(Base):
return np.array(res), token_count

def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=text[:2048],
api_key=self.key,
text_type="query"
)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
return np.array(resp["output"]["embeddings"][0]
["embedding"]), self.total_token_count(resp)
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)


class ZhipuEmbed(Base):
_FACTORY_NAME = "ZHIPU-AI"

def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
texts = [truncate(t, MAX_LEN) for t in texts]

for txt in texts:
res = self.client.embeddings.create(input=txt,
model=self.model_name)
res = self.client.embeddings.create(input=txt, model=self.model_name)
try:
arr.append(res.data[0].embedding)
tks_num += self.total_token_count(res)
@@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
return np.array(arr), tks_num

def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
res = self.client.embeddings.create(input=text, model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -265,18 +262,17 @@ class ZhipuEmbed(Base):


class OllamaEmbed(Base):
_FACTORY_NAME = "Ollama"

def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.model_name = model_name

def encode(self, texts: list):
arr = []
tks_num = 0
for txt in texts:
res = self.client.embeddings(prompt=txt,
model=self.model_name,
options={"use_mmap": True})
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
try:
arr.append(res["embedding"])
except Exception as _e:
@@ -285,9 +281,7 @@ class OllamaEmbed(Base):
return np.array(arr), tks_num

def encode_queries(self, text):
res = self.client.embeddings(prompt=text,
model=self.model_name,
options={"use_mmap": True})
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
try:
return np.array(res["embedding"]), 128
except Exception as _e:
@@ -295,27 +289,28 @@ class OllamaEmbed(Base):


class FastEmbed(DefaultEmbedding):
_FACTORY_NAME = "FastEmbed"

def __init__(
self,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
self,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
):
if not settings.LIGHTEN:
with FastEmbed._model_lock:
from fastembed import TextEmbedding

if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
DefaultEmbedding._model_name = model_name
except Exception:
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
local_dir=os.path.join(get_home_cache_dir(),
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
cache_dir = snapshot_download(
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
self._model = DefaultEmbedding._model
self._model_name = model_name
@@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):


class XinferenceEmbed(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name="", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
return np.array(ress), total_tokens

def encode_queries(self, text):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
res = self.client.embeddings.create(input=[text], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -368,20 +364,18 @@ class XinferenceEmbed(Base):


class YoudaoEmbed(Base):
_FACTORY_NAME = "Youdao"
_client = None

def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing

try:
logging.info("LOADING BCE...")
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_home_cache_dir(),
"bce-embedding-base_v1"))
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
except Exception:
YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

def encode(self, texts: list):
batch_size = 10
@@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
res.extend(embds)
return np.array(res), token_count

@@ -400,14 +394,11 @@ class YoudaoEmbed(Base):


class JinaEmbed(Base):
def __init__(self, key, model_name="jina-embeddings-v3",
base_url="https://api.jina.ai/v1/embeddings"):
_FACTORY_NAME = "Jina"

def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = "https://api.jina.ai/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name

def encode(self, texts: list):
@@ -416,11 +407,7 @@ class JinaEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
data = {
"model": self.model_name,
"input": texts[i:i + batch_size],
'encoding_type': 'float'
}
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
response = requests.post(self.base_url, headers=self.headers, json=data)
try:
res = response.json()
@@ -435,50 +422,12 @@ class JinaEmbed(Base):
return np.array(embds[0]), cnt
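Every embedding backend in this file keeps the same contract: encode(texts) returns a matrix of row vectors plus a token count, and encode_queries(text) returns a single vector with its token count. A minimal usage sketch, assuming the EmbeddingModel registry from rag/llm/__init__.py; the factory key, API key, and chunk strings are placeholders:

from rag.llm import EmbeddingModel

# Hedged usage sketch; "YOUR_API_KEY" and the chunk texts are placeholders.
embd = EmbeddingModel["Jina"](key="YOUR_API_KEY", model_name="jina-embeddings-v3")
vectors, tokens = embd.encode(["first chunk", "second chunk"])  # shape: (2, dim)
query_vec, _ = embd.encode_queries("what is rag?")              # shape: (dim,)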


class InfinityEmbed(Base):
_model = None

def __init__(
self,
model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
engine_kwargs: dict = {},
key = None,
):

from infinity_emb import EngineArgs
from infinity_emb.engine import AsyncEngineArray

self._default_model = model_names[0]
self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])

async def _embed(self, sentences: list[str], model_name: str = ""):
if not model_name:
model_name = self._default_model
engine = self.engine_array[model_name]
was_already_running = engine.is_running
if not was_already_running:
await engine.astart()
embeddings, usage = await engine.embed(sentences=sentences)
if not was_already_running:
await engine.astop()
return embeddings, usage

def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
embeddings, usage = asyncio.run(self._embed(texts, model_name))
return np.array(embeddings), usage

def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
# Using the internal tokenizer to encode the texts and get the total
# number of tokens
return self.encode([text])


class MistralEmbed(Base):
def __init__(self, key, model_name="mistral-embed",
base_url=None):
_FACTORY_NAME = "Mistral"

def __init__(self, key, model_name="mistral-embed", base_url=None):
from mistralai.client import MistralClient

self.client = MistralClient(api_key=key)
self.model_name = model_name

@@ -488,8 +437,7 @@ class MistralEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings(input=texts[i:i + batch_size],
model=self.model_name)
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
token_count += self.total_token_count(res)
@@ -498,8 +446,7 @@ class MistralEmbed(Base):
return np.array(ress), token_count

def encode_queries(self, text):
res = self.client.embeddings(input=[truncate(text, 8196)],
model=self.model_name)
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -507,30 +454,31 @@ class MistralEmbed(Base):


class BedrockEmbed(Base):
def __init__(self, key, model_name,
**kwargs):
_FACTORY_NAME = "Bedrock"

def __init__(self, key, model_name, **kwargs):
import boto3
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')

self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
self.client = boto3.client("bedrock-runtime")
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
embeddings = []
token_count = 0
for text in texts:
if self.model_name.split('.')[0] == 'amazon':
if self.model_name.split(".")[0] == "amazon":
body = {"inputText": text}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [text], "input_type": 'search_document'}
elif self.model_name.split(".")[0] == "cohere":
body = {"texts": [text], "input_type": "search_document"}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@@ -545,10 +493,10 @@ class BedrockEmbed(Base):
def encode_queries(self, text):
embeddings = []
token_count = num_tokens_from_string(text)
if self.model_name.split('.')[0] == 'amazon':
if self.model_name.split(".")[0] == "amazon":
body = {"inputText": truncate(text, 8196)}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
elif self.model_name.split(".")[0] == "cohere":
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@@ -561,11 +509,12 @@ class BedrockEmbed(Base):


class GeminiEmbed(Base):
def __init__(self, key, model_name='models/text-embedding-004',
**kwargs):
_FACTORY_NAME = "Gemini"

def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
self.key = key
self.model_name = 'models/' + model_name
self.model_name = "models/" + model_name
def encode(self, texts: list):
texts = [truncate(t, 2048) for t in texts]
token_count = sum(num_tokens_from_string(text) for text in texts)
@@ -573,35 +522,27 @@ class GeminiEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
result = genai.embed_content(
model=self.model_name,
content=texts[i: i + batch_size],
task_type="retrieval_document",
title="Embedding of single string")
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
try:
ress.extend(result['embedding'])
ress.extend(result["embedding"])
except Exception as _e:
log_exception(_e, result)
return np.array(ress),token_count
return np.array(ress), token_count

def encode_queries(self, text):
genai.configure(api_key=self.key)
result = genai.embed_content(
model=self.model_name,
content=truncate(text,2048),
task_type="retrieval_document",
title="Embedding of single string")
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
token_count = num_tokens_from_string(text)
try:
return np.array(result['embedding']), token_count
return np.array(result["embedding"]), token_count
except Exception as _e:
log_exception(_e, result)


class NvidiaEmbed(Base):
def __init__(
self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
self.api_key = key
@@ -645,6 +586,8 @@ class NvidiaEmbed(Base):


class LmStudioEmbed(LocalAIEmbed):
_FACTORY_NAME = "LM-Studio"

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):


class OpenAI_APIEmbed(OpenAIEmbed):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):


class CoHereEmbed(Base):
_FACTORY_NAME = "Cohere"

def __init__(self, key, model_name, base_url=None):
from cohere import Client

@@ -701,6 +648,8 @@ class CoHereEmbed(Base):


class TogetherAIEmbed(OpenAIEmbed):
_FACTORY_NAME = "TogetherAI"

def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):


class PerfXCloudEmbed(OpenAIEmbed):
_FACTORY_NAME = "PerfXCloud"

def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
@@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):


class UpstageEmbed(OpenAIEmbed):
_FACTORY_NAME = "Upstage"

def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
@@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):


class SILICONFLOWEmbed(Base):
_FACTORY_NAME = "SILICONFLOW"

def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/embeddings"
@@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):


class ReplicateEmbed(Base):
_FACTORY_NAME = "Replicate"

def __init__(self, key, model_name, base_url=None):
from replicate.client import Client

@@ -790,6 +747,8 @@ class ReplicateEmbed(Base):


class BaiduYiyanEmbed(Base):
_FACTORY_NAME = "BaiduYiyan"

def __init__(self, key, model_name, base_url=None):
import qianfan

@@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):


class VoyageEmbed(Base):
_FACTORY_NAME = "Voyage AI"

def __init__(self, key, model_name, base_url=None):
import voyageai

@@ -832,9 +793,7 @@ class VoyageEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embed(
texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
)
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
try:
ress.extend(res.embeddings)
token_count += res.total_tokens
@@ -843,9 +802,7 @@ class VoyageEmbed(Base):
return np.array(ress), token_count

def encode_queries(self, text):
res = self.client.embed(
texts=text, model=self.model_name, input_type="query"
)
res = self.client.embed(texts=text, model=self.model_name, input_type="query")
try:
return np.array(res.embeddings)[0], res.total_tokens
except Exception as _e:
@@ -853,6 +810,8 @@ class VoyageEmbed(Base):


class HuggingFaceEmbed(Base):
_FACTORY_NAME = "HuggingFace"

def __init__(self, key, model_name, base_url=None):
if not model_name:
raise ValueError("Model name cannot be None")
@@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
def encode(self, texts: list):
embeddings = []
for text in texts:
response = requests.post(
f"{self.base_url}/embed",
json={"inputs": text},
headers={'Content-Type': 'application/json'}
)
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
embeddings.append(embedding[0])
@@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

def encode_queries(self, text):
response = requests.post(
f"{self.base_url}/embed",
json={"inputs": text},
headers={'Content-Type': 'application/json'}
)
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
return np.array(embedding[0]), num_tokens_from_string(text)
@@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):


class VolcEngineEmbed(OpenAIEmbed):
_FACTORY_NAME = "VolcEngine"

def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
if not base_url:
base_url = "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
super().__init__(ark_api_key,model_name,base_url)
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url)


class GPUStackEmbed(OpenAIEmbed):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):


class NovitaEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "NovitaAI"

def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/embeddings"
@@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed):


class GiteeEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
if not base_url:
base_url = "https://ai.gitee.com/v1/embeddings"
super().__init__(key, model_name, base_url)
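The _FACTORY_NAME attributes added to every class above are what the automatic registration in rag/llm/__init__.py keys on. A minimal sketch of how such a scan can populate a factory map; the function and variable names here are illustrative assumptions, not the exact implementation:

import importlib
import inspect

def register_factories(module_name: str, base_class: type) -> dict[str, type]:
    # Collect every subclass of base_class in the module that declares a
    # _FACTORY_NAME; the attribute may be one string or a list of strings
    # (e.g. OpenAI_APIEmbed serves both "VLLM" and "OpenAI-API-Compatible").
    registry: dict[str, type] = {}
    module = importlib.import_module(module_name)
    for _, klass in inspect.getmembers(module, inspect.isclass):
        if not issubclass(klass, base_class) or klass is base_class:
            continue
        names = getattr(klass, "_FACTORY_NAME", None)
        if not names:
            continue
        for name in names if isinstance(names, list) else [names]:
            registry[name] = klass
    return registry

Adding a new provider then only requires defining the class with a _FACTORY_NAME; there is no hand-maintained import list to touch.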

+ 74
- 92
rag/llm/rerank_model.py View File

@@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin

import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate


def sigmoid(x):
@@ -57,6 +57,7 @@ class Base(ABC):


class DefaultRerank(Base):
_FACTORY_NAME = "BAAI"
_model = None
_model_lock = threading.Lock()

@@ -75,17 +76,13 @@ class DefaultRerank(Base):
if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker

with DefaultRerank._model_lock:
if not DefaultRerank._model:
try:
DefaultRerank._model = FlagReranker(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
use_fp16=torch.cuda.is_available())
DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
except Exception:
model_dir = snapshot_download(repo_id=model_name,
local_dir=os.path.join(get_home_cache_dir(),
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
@@ -94,6 +91,7 @@ class DefaultRerank(Base):
def torch_empty_cache(self):
try:
import torch

torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
while retry_count < max_retries:
try:
# call the subclass's batch scoring implementation
batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
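For context on the loop above: the batch size grows back toward its cap after every successful slice, and the matching except branch (outside this hunk) presumably halves it on an out-of-memory error before retrying the same slice. A hedged sketch of the overall pattern, with names assumed for illustration:

def score_in_batches(pairs, compute_batch, max_batch=8, max_retries=5):
    # Adaptive batching: shrink the batch on failure, grow it back on success.
    res, i, batch = [], 0, max_batch
    while i < len(pairs):
        retries = 0
        while True:
            try:
                res.extend(compute_batch(pairs[i : i + batch]))
                i += batch
                batch = min(batch * 2, max_batch)  # recover after a success
                break
            except RuntimeError:  # e.g. CUDA out of memory
                retries += 1
                if retries >= max_retries:
                    raise
                batch = max(1, batch // 2)  # halve and retry the same slice
    return res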
@@ -152,23 +150,16 @@ class DefaultRerank(Base):


class JinaRerank(Base):
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
base_url="https://api.jina.ai/v1/rerank"):
_FACTORY_NAME = "Jina"

def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
self.base_url = "https://api.jina.ai/v1/rerank"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name

def similarity(self, query: str, texts: list):
texts = [truncate(t, 8196) for t in texts]
data = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts)
}
data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
@@ -180,22 +171,20 @@ class JinaRerank(Base):


class YoudaoRerank(DefaultRerank):
_FACTORY_NAME = "Youdao"
_model = None
_model_lock = threading.Lock()

def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel

with YoudaoRerank._model_lock:
if not YoudaoRerank._model:
try:
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
get_home_cache_dir(),
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
except Exception:
YoudaoRerank._model = RerankerModel(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))
YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

self._model = YoudaoRerank._model
self._dynamic_batch_size = 8
@@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):


class XInferenceRerank(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key="x", model_name="", base_url=""):
if base_url.find("/v1") == -1:
base_url = urljoin(base_url, "/v1/rerank")
@@ -219,10 +210,7 @@ class XInferenceRerank(Base):
base_url = urljoin(base_url, "/v1/rerank")
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json",
"accept": "application/json"
}
self.headers = {"Content-Type": "application/json", "accept": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bearer {key}"

@@ -233,13 +221,7 @@ class XInferenceRerank(Base):
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
data = {
"model": self.model_name,
"query": query,
"return_documents": "true",
"return_len": "true",
"documents": texts
}
data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
@@ -251,15 +233,14 @@ class XInferenceRerank(Base):


class LocalAIRerank(Base):
_FACTORY_NAME = "LocalAI"

def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]

def similarity(self, query: str, texts: list):
@@ -296,16 +277,15 @@ class LocalAIRerank(Base):


class NvidiaRerank(Base):
def __init__(
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
self.model_name = model_name

if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
)
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")

if self.model_name == "nvidia/rerank-qa-mistral-4b":
self.base_url = urljoin(base_url, "reranking")
@@ -318,9 +298,7 @@ class NvidiaRerank(Base):
}

def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum(
[num_tokens_from_string(t) for t in texts]
)
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
data = {
"model": self.model_name,
"query": {"text": query},
@@ -339,6 +317,8 @@ class NvidiaRerank(Base):


class LmStudioRerank(Base):
_FACTORY_NAME = "LM-Studio"

def __init__(self, key, model_name, base_url):
pass

@@ -347,15 +327,14 @@ class LmStudioRerank(Base):


class OpenAI_APIRerank(Base):
_FACTORY_NAME = "OpenAI-API-Compatible"

def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]

def similarity(self, query: str, texts: list):
@@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):


class CoHereRerank(Base):
_FACTORY_NAME = ["Cohere", "VLLM"]

def __init__(self, key, model_name, base_url=None):
from cohere import Client

@@ -399,9 +380,7 @@ class CoHereRerank(Base):
self.model_name = model_name.split("___")[0]

def similarity(self, query: str, texts: list):
token_count = num_tokens_from_string(query) + sum(
[num_tokens_from_string(t) for t in texts]
)
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
res = self.client.rerank(
model=self.model_name,
query=query,
@@ -419,6 +398,8 @@ class CoHereRerank(Base):


class TogetherAIRerank(Base):
_FACTORY_NAME = "TogetherAI"

def __init__(self, key, model_name, base_url):
pass

@@ -427,9 +408,9 @@ class TogetherAIRerank(Base):


class SILICONFLOWRerank(Base):
def __init__(
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/rerank"
self.model_name = model_name
@@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
}
response = requests.post(
self.base_url, json=payload, headers=self.headers
).json()
response = requests.post(self.base_url, json=payload, headers=self.headers).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in response["results"]:
@@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):


class BaiduYiyanRerank(Base):
_FACTORY_NAME = "BaiduYiyan"

def __init__(self, key, model_name, base_url=None):
from qianfan.resources import Reranker

@@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):


class VoyageRerank(Base):
_FACTORY_NAME = "Voyage AI"

def __init__(self, key, model_name, base_url=None):
import voyageai

@@ -502,9 +485,7 @@ class VoyageRerank(Base):
rank = np.zeros(len(texts), dtype=float)
if not texts:
return rank, 0
res = self.client.rerank(
query=query, documents=texts, model=self.model_name, top_k=len(texts)
)
res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
try:
for r in res.results:
rank[r.index] = r.relevance_score
@@ -514,22 +495,20 @@ class VoyageRerank(Base):


class QWenRerank(Base):
def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
import dashscope

self.api_key = key
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

def similarity(self, query: str, texts: list):
import dashscope
from http import HTTPStatus
resp = dashscope.TextReRank.call(
api_key=self.api_key,
model=self.model_name,
query=query,
documents=texts,
top_n=len(texts),
return_documents=False
)

import dashscope

resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
rank = np.zeros(len(texts), dtype=float)
if resp.status_code == HTTPStatus.OK:
try:
@@ -543,6 +522,8 @@ class QWenRerank(Base):


class HuggingfaceRerank(DefaultRerank):
_FACTORY_NAME = "HuggingFace"

@staticmethod
def post(query: str, texts: list, url="127.0.0.1"):
exc = None
@@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
batch_size = 8
for i in range(0, len(texts), batch_size):
try:
res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
json={"query": query, "texts": texts[i: i + batch_size],
"raw_scores": False, "truncate": True})
res = requests.post(
f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
)

for o in res.json():
scores[o["index"] + i] = o["score"]
@@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):


class GPUStackRerank(Base):
def __init__(
self, key, model_name, base_url
):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")

@@ -600,9 +581,7 @@ class GPUStackRerank(Base):
}

try:
response = requests.post(
self.base_url, json=payload, headers=self.headers
)
response = requests.post(self.base_url, json=payload, headers=self.headers)
response.raise_for_status()
response_json = response.json()

@@ -623,11 +602,12 @@ class GPUStackRerank(Base):
)

except requests.HTTPError as e:
raise ValueError(
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")


class NovitaRerank(JinaRerank):
_FACTORY_NAME = "NovitaAI"

def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/rerank"
@@ -635,7 +615,9 @@ class NovitaRerank(JinaRerank):


class GiteeRerank(JinaRerank):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
if not base_url:
base_url = "https://ai.gitee.com/v1/rerank"
super().__init__(key, model_name, base_url)
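Every reranker above shares the similarity(query, texts) contract: a NumPy score array aligned with texts, plus a token count. A hedged usage sketch; the factory key, API key, and chunks are placeholders:

import numpy as np

from rag.llm import RerankModel

reranker = RerankModel["Jina"](key="YOUR_API_KEY", model_name="jina-reranker-v2-base-multilingual")
scores, tokens = reranker.similarity("what is rag?", ["chunk one", "chunk two"])
best_first = np.argsort(scores)[::-1]  # chunk indices, highest score first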

+ 42
- 45
rag/llm/sequence2txt_model.py View File

@@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string


class Base(ABC):
@@ -30,11 +32,7 @@ class Base(ABC):
pass

def transcription(self, audio, **kwargs):
transcription = self.client.audio.transcriptions.create(
model=self.model_name,
file=audio,
response_format="text"
)
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

def audio2base64(self, audio):
@@ -46,6 +44,8 @@ class Base(ABC):


class GPTSeq2txt(Base):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -54,31 +54,34 @@ class GPTSeq2txt(Base):


class QWenSeq2txt(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
import dashscope

dashscope.api_key = key
self.model_name = model_name

def transcription(self, audio, format):
from http import HTTPStatus

from dashscope.audio.asr import Recognition

recognition = Recognition(model=self.model_name,
format=format,
sample_rate=16000,
callback=None)
recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
result = recognition.call(audio)

ans = ""
if result.status_code == HTTPStatus.OK:
for sentence in result.get_sentence():
ans += sentence.text.decode('utf-8') + '\n'
ans += sentence.text.decode("utf-8") + "\n"
return ans, num_tokens_from_string(ans)

return "**ERROR**: " + result.message, 0


class AzureSeq2txt(Base):
_FACTORY_NAME = "Azure-OpenAI"

def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
@@ -86,43 +89,33 @@ class AzureSeq2txt(Base):


class XinferenceSeq2txt(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get('base_url', None)
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.key = key

def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
audio_file = open(audio, 'rb')
audio_file = open(audio, "rb")
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"

payload = {
"model": self.model_name,
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature
}
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}

files = {
"file": (audio_file_name, audio_data, 'audio/wav')
}
files = {"file": (audio_file_name, audio_data, "audio/wav")}

try:
response = requests.post(
f"{self.base_url}/v1/audio/transcriptions",
files=files,
data=payload
)
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
response.raise_for_status()
result = response.json()

if 'text' in result:
transcription_text = result['text'].strip()
if "text" in result:
transcription_text = result["text"].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
@@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):


class TencentCloudSeq2txt(Base):
def __init__(
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
):
from tencentcloud.common import credential
_FACTORY_NAME = "Tencent Cloud"

def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
from tencentcloud.asr.v20190614 import asr_client
from tencentcloud.common import credential

key = json.loads(key)
sid = key.get("tencent_cloud_sid", "")
@@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
self.model_name = model_name

def transcription(self, audio, max_retries=60, retry_interval=5):
import time

from tencentcloud.asr.v20190614 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
    TencentCloudSDKException,
)

b64 = self.audio2base64(audio)
try:
@@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
while retries < max_retries:
resp = self.client.DescribeTaskStatus(req)
if resp.Data.StatusStr == "success":
text = re.sub(
r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
).strip()
text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
return text, num_tokens_from_string(text)
elif resp.Data.StatusStr == "failed":
return (
@@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):


class GPUStackSeq2txt(Base):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):


class GiteeSeq2txt(Base):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
if not base_url:
base_url = "https://ai.gitee.com/v1/"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
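The speech-to-text classes above all implement transcription(audio, ...) and return the recognized text plus a token count. A hedged usage sketch; the API key and file path are placeholders:

from rag.llm import Seq2txtModel

stt = Seq2txtModel["OpenAI"](key="YOUR_API_KEY", model_name="whisper-1")
with open("meeting.wav", "rb") as audio:
    text, tokens = stt.transcription(audio)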


+ 50
- 92
rag/llm/tts_model.py View File

@@ -70,10 +70,12 @@ class Base(ABC):
pass

def normalize_text(self, text):
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


class FishAudioTTS(Base):
_FACTORY_NAME = "Fish Audio"

def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
if not base_url:
base_url = "https://api.fish.audio/v1/tts"
@@ -94,13 +96,11 @@ class FishAudioTTS(Base):
with httpx.Client() as client:
try:
with client.stream(
method="POST",
url=self.base_url,
content=ormsgpack.packb(
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
),
headers=self.headers,
timeout=None,
method="POST",
url=self.base_url,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
headers=self.headers,
timeout=None,
) as response:
if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes():
@@ -115,6 +115,8 @@ class FishAudioTTS(Base):


class QwenTTS(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

def __init__(self, key, model_name, base_url=""):
import dashscope

@@ -122,10 +124,11 @@ class QwenTTS(Base):
dashscope.api_key = key

def tts(self, text):
from collections import deque

from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()
@@ -159,10 +162,7 @@ class QwenTTS(Base):

text = self.normalize_text(text)
callback = Callback()
SpeechSynthesizer.call(model=self.model_name,
text=text,
callback=callback,
format="mp3")
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
try:
for data in callback._run():
yield data
@@ -173,24 +173,19 @@ class QwenTTS(Base):


class OpenAITTS(Base):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)

@@ -201,7 +196,8 @@ class OpenAITTS(Base):
yield chunk


class SparkTTS:
class SparkTTS(Base):
_FACTORY_NAME = "XunFei Spark"
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2
@@ -219,29 +215,23 @@ class SparkTTS:

# generate the signed request URL
def create_url(self):
url = 'wss://tts-api.xfyun.cn/v2/tts'
url = "wss://tts-api.xfyun.cn/v2/tts"
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": "ws-api.xfyun.cn"
}
url = url + '?' + urlencode(v)
signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
url = url + "?" + urlencode(v)
return url

def tts(self, text):
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
CommonArgs = {"app_id": self.APPID}
audio_queue = self.audio_queue
model_name = self.model_name
@@ -273,9 +263,7 @@ class SparkTTS:

def on_open(self, ws):
def run(*args):
d = {"common": CommonArgs,
"business": BusinessArgs,
"data": Data}
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
ws.send(json.dumps(d))

thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@ class SparkTTS:
wsUrl = self.create_url()
websocket.enableTrace(False)
a = Callback()
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
on_message=a.on_message)
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
status_code = 0
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
while True:
audio_chunk = self.audio_queue.get()
if audio_chunk is None:
if status_code == 0:
raise Exception(
f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
else:
break
status_code = 1
yield audio_chunk


class XinferenceTTS:
class XinferenceTTS(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json"}

def tts(self, text, voice="中文女", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -332,22 +308,16 @@ class XinferenceTTS:

class OllamaTTS(Base):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json"
}
self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bear {key}"

def tts(self, text, voice="standard-voice"):
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

@@ -359,30 +329,19 @@ class OllamaTTS(Base):
yield chunk


class GPUStackTTS:
class GPUStackTTS(Base):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

def tts(self, text, voice="Chinese Female", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS:


class SILICONFLOWTTS(Base):
_FACTORY_NAME = "SILICONFLOW"

def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

def tts(self, text, voice="anna"):
text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
"sample_rate": 123,
"stream": True,
"speed": 1,
"gain": 0
"gain": 0,
}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)

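All TTS backends above expose tts(text, ...) as a generator of raw audio bytes, so callers can stream chunks straight to a file or socket. A hedged usage sketch; the API key, voice, and output path are placeholders:

from rag.llm import TTSModel

tts = TTSModel["OpenAI"](key="YOUR_API_KEY", model_name="tts-1")
with open("reply.mp3", "wb") as out:
    for chunk in tts.tts("Hello from RAGFlow.", voice="alloy"):
        out.write(chunk)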