浏览代码

feat: add `supported_model_types` field and filter in provider list (#1581)

tags/0.3.31
takatost 1年前
父节点
当前提交
c9368925a3
没有帐户链接到提交者的电子邮件

+ 5
- 1
api/controllers/console/workspace/model_providers.py 查看文件

def get(self): def get(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id


parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()

provider_service = ProviderService() provider_service = ProviderService()
provider_list = provider_service.get_provider_list(tenant_id)
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))


return provider_list return provider_list



+ 3
- 0
api/core/model_providers/rules/anthropic.json 查看文件

"quota_limit": 0 "quota_limit": 0
}, },
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
],
"price_config": { "price_config": {
"claude-instant-1": { "claude-instant-1": {
"prompt": "1.63", "prompt": "1.63",

+ 4
- 0
api/core/model_providers/rules/azure_openai.json 查看文件

], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable", "model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
],
"price_config":{ "price_config":{
"gpt-4": { "gpt-4": {
"prompt": "0.03", "prompt": "0.03",

+ 3
- 0
api/core/model_providers/rules/baichuan.json 查看文件

], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
],
"price_config": { "price_config": {
"baichuan2-53b": { "baichuan2-53b": {
"prompt": "0.01", "prompt": "0.01",

+ 4
- 1
api/core/model_providers/rules/chatglm.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
]
} }

+ 4
- 1
api/core/model_providers/rules/cohere.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"supported_model_types": [
"reranking"
]
} }

+ 5
- 1
api/core/model_providers/rules/huggingface_hub.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
]
} }

+ 5
- 1
api/core/model_providers/rules/localai.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
]
} }

+ 4
- 0
api/core/model_providers/rules/minimax.json 查看文件

"quota_unit": "tokens" "quota_unit": "tokens"
}, },
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation",
"embeddings"
],
"price_config": { "price_config": {
"abab5.5-chat": { "abab5.5-chat": {
"prompt": "0.015", "prompt": "0.015",

+ 6
- 0
api/core/model_providers/rules/openai.json 查看文件

"quota_limit": 200 "quota_limit": 200
}, },
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation",
"embeddings",
"speech2text",
"moderation"
],
"price_config": { "price_config": {
"gpt-4": { "gpt-4": {
"prompt": "0.03", "prompt": "0.03",

+ 5
- 1
api/core/model_providers/rules/openllm.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
]
} }

+ 5
- 1
api/core/model_providers/rules/replicate.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
]
} }

+ 3
- 0
api/core/model_providers/rules/spark.json 查看文件

"quota_unit": "tokens" "quota_unit": "tokens"
}, },
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
],
"price_config": { "price_config": {
"spark": { "spark": {
"prompt": "0.18", "prompt": "0.18",

+ 3
- 0
api/core/model_providers/rules/tongyi.json 查看文件

], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
],
"price_config": { "price_config": {
"qwen-turbo": { "qwen-turbo": {
"prompt": "0.012", "prompt": "0.012",

+ 3
- 0
api/core/model_providers/rules/wenxin.json 查看文件

], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation"
],
"price_config": { "price_config": {
"ernie-bot-4": { "ernie-bot-4": {
"prompt": "0", "prompt": "0",

+ 5
- 1
api/core/model_providers/rules/xinference.json 查看文件

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"supported_model_types": [
"text-generation",
"embeddings"
]
} }

+ 4
- 0
api/core/model_providers/rules/zhipuai.json 查看文件

"quota_unit": "tokens" "quota_unit": "tokens"
}, },
"model_flexibility": "fixed", "model_flexibility": "fixed",
"supported_model_types": [
"text-generation",
"embeddings"
],
"price_config": { "price_config": {
"chatglm_turbo": { "chatglm_turbo": {
"prompt": "0.005", "prompt": "0.005",

+ 7
- 2
api/services/provider_service.py 查看文件



class ProviderService: class ProviderService:


def get_provider_list(self, tenant_id: str):
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
""" """
get provider list of tenant. get provider list of tenant.


:param tenant_id:
:param tenant_id: workspace id
:param model_type: filter by model type
:return: :return:
""" """
# get rules for all providers # get rules for all providers
providers_list = {} providers_list = {}


for model_provider_name, model_provider_rule in model_provider_rules.items(): for model_provider_name, model_provider_rule in model_provider_rules.items():
if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
continue

# get preferred provider type # get preferred provider type
preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name) preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider( preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
provider_config_dict = { provider_config_dict = {
"preferred_provider_type": preferred_provider_type, "preferred_provider_type": preferred_provider_type,
"model_flexibility": model_provider_rule['model_flexibility'], "model_flexibility": model_provider_rule['model_flexibility'],
"supported_model_types": model_provider_rule.get("supported_model_types", []),
} }


provider_parameter_dict = {} provider_parameter_dict = {}

正在加载...
取消
保存