Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.0.0
| @@ -52,19 +52,12 @@ class HostingConfiguration: | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||
| self.provider_map["openai"] = self.init_openai() | |||
| self.provider_map["anthropic"] = self.init_anthropic() | |||
| self.provider_map["minimax"] = self.init_minimax() | |||
| self.provider_map["spark"] = self.init_spark() | |||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||
| # NOTE: We need to use the new name format after the data migration. | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai() | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai() | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic() | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() | |||
| # self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() | |||
| self.moderation_config = self.init_moderation_config() | |||
| @@ -6,6 +6,7 @@ from typing import Any, Optional, cast | |||
| from sqlalchemy.exc import IntegrityError | |||
| from configs import dify_config | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | |||
| from core.entities.provider_entities import ( | |||
| @@ -504,11 +505,14 @@ class ProviderManager: | |||
| if quota.quota_type == ProviderQuotaType.TRIAL: | |||
| # Init trial provider records if not exists | |||
| if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: | |||
| if not provider_name.startswith(DEFAULT_PLUGIN_ID): | |||
| continue | |||
| hosting_provider_name = provider_name.split("/")[-1] | |||
| try: | |||
| # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic | |||
| provider_record = Provider( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_name, | |||
| provider_name=hosting_provider_name, | |||
| provider_type=ProviderType.SYSTEM.value, | |||
| quota_type=ProviderQuotaType.TRIAL.value, | |||
| quota_limit=quota.quota_limit, # type: ignore | |||
| @@ -523,7 +527,7 @@ class ProviderManager: | |||
| db.session.query(Provider) | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == provider_name, | |||
| Provider.provider_name == hosting_provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||
| ) | |||