浏览代码

refactor add LLM (#2508)

### What problem does this PR solve?

#2487

### Type of change

- [x] Refactoring
tags/v0.12.0
Kevin Hu 1年前
父节点
当前提交
5968f148bc
没有帐户链接到提交者的电子邮件
共有 2 个文件被更改,包括 26 次插入23 次删除
  1. 25
    22
      api/apps/llm_app.py
  2. 1
    1
      rag/llm/chat_model.py

+ 25
- 22
api/apps/llm_app.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json

from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
req = request.json req = request.json
factory = req["llm_factory"] factory = req["llm_factory"]


def apikey_json(keys):
nonlocal req
return json.dumps({k: req.get(k, "") for k in keys})

if factory == "VolcEngine": if factory == "VolcEngine":
# For VolcEngine, due to its special authentication method # For VolcEngine, due to its special authentication method
# Assemble ark_api_key endpoint_id into api_key # Assemble ark_api_key endpoint_id into api_key
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = f'{{ "ark_api_key":"{req.get("ark_api_key", "")}", "ep_id":"{req.get("endpoint_id", "")}" }}'
api_key = apikey_json(["ark_api_key", "endpoint_id"])

elif factory == "Tencent Hunyuan": elif factory == "Tencent Hunyuan":
api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \
f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
req["api_key"] = api_key
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
return set_api_key() return set_api_key()

elif factory == "Tencent Cloud": elif factory == "Tencent Cloud":
api_key = '{' + f'"tencent_cloud_sid": "{req.get("tencent_cloud_sid", "")}", ' \
f'"tencent_cloud_sk": "{req.get("tencent_cloud_sk", "")}"' + '}'
req["api_key"] = api_key
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])

elif factory == "Bedrock": elif factory == "Bedrock":
# For Bedrock, due to its special authentication method # For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region # Assemble bedrock_ak, bedrock_sk, bedrock_region
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])

elif factory == "LocalAI": elif factory == "LocalAI":
llm_name = req["llm_name"]+"___LocalAI" llm_name = req["llm_name"]+"___LocalAI"
api_key = "xxxxxxxxxxxxxxx" api_key = "xxxxxxxxxxxxxxx"

elif factory == "OpenAI-API-Compatible": elif factory == "OpenAI-API-Compatible":
llm_name = req["llm_name"]+"___OpenAI-API" llm_name = req["llm_name"]+"___OpenAI-API"
api_key = req.get("api_key","xxxxxxxxxxxxxxx") api_key = req.get("api_key","xxxxxxxxxxxxxxx")

elif factory =="XunFei Spark": elif factory =="XunFei Spark":
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")

elif factory == "BaiduYiyan": elif factory == "BaiduYiyan":
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
elif factory == "Fish Audio": elif factory == "Fish Audio":
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \
f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}'
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
elif factory == "Google Cloud": elif factory == "Google Cloud":
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = (
"{" + f'"google_project_id": "{req.get("google_project_id", "")}", '
f'"google_region": "{req.get("google_region", "")}", '
f'"google_service_account_key": "{req.get("google_service_account_key", "")}"'
+ "}"
)
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])

else: else:
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")


llm = { llm = {
"tenant_id": current_user.id, "tenant_id": current_user.id,

+ 1
- 1
rag/llm/chat_model.py 查看文件

""" """
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3' base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
ark_api_key = json.loads(key).get('ark_api_key', '') ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
super().__init__(ark_api_key, model_name, base_url) super().__init__(ark_api_key, model_name, base_url)





正在加载...
取消
保存