浏览代码

Refa: add model. (#5820)

### What problem does this PR solve?

#5783

### Type of change

- [x] Refactoring
tags/v0.17.1
Kevin Hu 7 个月前
父节点
当前提交
82f5d901c8
共有 1 个文件被更改,包括 7 次插入20 次删除
  1. 7
    20
      api/apps/llm_app.py

+ 7
- 20
api/apps/llm_app.py 查看文件

def add_llm(): def add_llm():
req = request.json req = request.json
factory = req["llm_factory"] factory = req["llm_factory"]
api_key = req.get("api_key", "")
llm_name = req["llm_name"]


def apikey_json(keys): def apikey_json(keys):
nonlocal req nonlocal req
if factory == "VolcEngine": if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method # For VolcEngine, due to its special authentication method
# Assemble ark_api_key endpoint_id into api_key # Assemble ark_api_key endpoint_id into api_key
llm_name = req["llm_name"]
api_key = apikey_json(["ark_api_key", "endpoint_id"]) api_key = apikey_json(["ark_api_key", "endpoint_id"])


elif factory == "Tencent Hunyuan": elif factory == "Tencent Hunyuan":
elif factory == "Bedrock": elif factory == "Bedrock":
# For Bedrock, due to its special authentication method # For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region # Assemble bedrock_ak, bedrock_sk, bedrock_region
llm_name = req["llm_name"]
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"]) api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])


elif factory == "LocalAI": elif factory == "LocalAI":
llm_name = req["llm_name"] + "___LocalAI"
api_key = "xxxxxxxxxxxxxxx"
llm_name += "___LocalAI"


elif factory == "HuggingFace": elif factory == "HuggingFace":
llm_name = req["llm_name"] + "___HuggingFace"
api_key = "xxxxxxxxxxxxxxx"
llm_name += "___HuggingFace"


elif factory == "OpenAI-API-Compatible": elif factory == "OpenAI-API-Compatible":
llm_name = req["llm_name"] + "___OpenAI-API"
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
llm_name += "___OpenAI-API"


elif factory == "VLLM": elif factory == "VLLM":
llm_name = req["llm_name"] + "___VLLM"
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
llm_name += "___VLLM"


elif factory == "XunFei Spark": elif factory == "XunFei Spark":
llm_name = req["llm_name"]
if req["model_type"] == "chat": if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
api_key = req.get("spark_api_password", "")
elif req["model_type"] == "tts": elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"]) api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])


elif factory == "BaiduYiyan": elif factory == "BaiduYiyan":
llm_name = req["llm_name"]
api_key = apikey_json(["yiyan_ak", "yiyan_sk"]) api_key = apikey_json(["yiyan_ak", "yiyan_sk"])


elif factory == "Fish Audio": elif factory == "Fish Audio":
llm_name = req["llm_name"]
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"]) api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])


elif factory == "Google Cloud": elif factory == "Google Cloud":
llm_name = req["llm_name"]
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"]) api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])


elif factory == "Azure-OpenAI": elif factory == "Azure-OpenAI":
llm_name = req["llm_name"]
api_key = apikey_json(["api_key", "api_version"]) api_key = apikey_json(["api_key", "api_version"])


else:
llm_name = req["llm_name"]
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")

llm = { llm = {
"tenant_id": current_user.id, "tenant_id": current_user.id,
"llm_factory": factory, "llm_factory": factory,

正在加载...
取消
保存