Przeglądaj źródła

chore(provider_manager): Update hosted model's name (#14334)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.0.0
-LAN- 8 miesięcy temu
rodzic
commit
76bcdc2581
No account linked to committer's email address
2 zmienionych plików z 12 dodań i 15 usunięć
  1. 6
    13
      api/core/hosting_configuration.py
  2. 6
    2
      api/core/provider_manager.py

+ 6
- 13
api/core/hosting_configuration.py Wyświetl plik

if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return


self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()
# NOTE: We need to use the new name format after the data migration.
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
# self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()


self.moderation_config = self.init_moderation_config() self.moderation_config = self.init_moderation_config()



+ 6
- 2
api/core/provider_manager.py Wyświetl plik

from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError


from configs import dify_config from configs import dify_config
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import ( from core.entities.provider_entities import (
if quota.quota_type == ProviderQuotaType.TRIAL: if quota.quota_type == ProviderQuotaType.TRIAL:
# Init trial provider records if not exists # Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if not provider_name.startswith(DEFAULT_PLUGIN_ID):
continue
hosting_provider_name = provider_name.split("/")[-1]
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider( provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name=provider_name,
provider_name=hosting_provider_name,
provider_type=ProviderType.SYSTEM.value, provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value, quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit, # type: ignore quota_limit=quota.quota_limit, # type: ignore
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
Provider.provider_name == provider_name,
Provider.provider_name == hosting_provider_name,
Provider.provider_type == ProviderType.SYSTEM.value, Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value, Provider.quota_type == ProviderQuotaType.TRIAL.value,
) )

Ładowanie…
Anuluj
Zapisz