| @@ -7,7 +7,6 @@ from json import JSONDecodeError | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from sqlalchemy import or_ | |||
| from constants import HIDDEN_VALUE | |||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | |||
| @@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel): | |||
| else [], | |||
| ) | |||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||
| def _get_custom_provider_credentials(self) -> Provider | None: | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| :return: | |||
| Get custom provider credentials. | |||
| """ | |||
| # get provider | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| or_( | |||
| Provider.provider_name == model_provider_id.provider_name, | |||
| Provider.provider_name == self.provider.provider, | |||
| ), | |||
| ) | |||
| .first() | |||
| ) | |||
| else: | |||
| provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| Provider.provider_name == self.provider.provider, | |||
| ) | |||
| .first() | |||
| provider_names.append(model_provider_id.provider_name) | |||
| provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| Provider.provider_name.in_(provider_names), | |||
| ) | |||
| .first() | |||
| ) | |||
| return provider_record | |||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| :return: | |||
| """ | |||
| provider_record = self._get_custom_provider_credentials() | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| @@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| # get provider | |||
| provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| or_( | |||
| Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, | |||
| Provider.provider_name == self.provider.provider, | |||
| ), | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| ) | |||
| .first() | |||
| ) | |||
| provider_record = self._get_custom_provider_credentials() | |||
| # delete provider | |||
| if provider_record: | |||
| @@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel): | |||
| return None | |||
| def custom_model_credentials_validate( | |||
| self, model_type: ModelType, model: str, credentials: dict | |||
| ) -> tuple[ProviderModel | None, dict]: | |||
| def _get_custom_model_credentials( | |||
| self, | |||
| model_type: ModelType, | |||
| model: str, | |||
| ) -> ProviderModel | None: | |||
| """ | |||
| Validate custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| Get custom model credentials. | |||
| """ | |||
| # get provider model | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| provider_model_record = ( | |||
| db.session.query(ProviderModel) | |||
| .filter( | |||
| ProviderModel.tenant_id == self.tenant_id, | |||
| ProviderModel.provider_name == self.provider.provider, | |||
| ProviderModel.provider_name.in_(provider_names), | |||
| ProviderModel.model_name == model, | |||
| ProviderModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| .first() | |||
| ) | |||
| return provider_model_record | |||
| def custom_model_credentials_validate( | |||
| self, model_type: ModelType, model: str, credentials: dict | |||
| ) -> tuple[ProviderModel | None, dict]: | |||
| """ | |||
| Validate custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| # get provider model | |||
| provider_model_record = self._get_custom_model_credentials(model_type, model) | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self.extract_secret_variables( | |||
| self.provider.model_credential_schema.credential_form_schemas | |||
| @@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| # get provider model | |||
| provider_model_record = ( | |||
| db.session.query(ProviderModel) | |||
| .filter( | |||
| ProviderModel.tenant_id == self.tenant_id, | |||
| ProviderModel.provider_name == self.provider.provider, | |||
| ProviderModel.model_name == model, | |||
| ProviderModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| .first() | |||
| ) | |||
| provider_model_record = self._get_custom_model_credentials(model_type, model) | |||
| # delete provider model | |||
| if provider_model_record: | |||
| @@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel): | |||
| provider_model_credentials_cache.delete() | |||
| def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | |||
| def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: | |||
| """ | |||
| Enable model. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :return: | |||
| Get provider model setting. | |||
| """ | |||
| model_setting = ( | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| return ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name == self.provider.provider, | |||
| ProviderModelSetting.provider_name.in_(provider_names), | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| ProviderModelSetting.model_name == model, | |||
| ) | |||
| .first() | |||
| ) | |||
| def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | |||
| """ | |||
| Enable model. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| model_setting = self._get_provider_model_setting(model_type, model) | |||
| if model_setting: | |||
| model_setting.enabled = True | |||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| @@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel): | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| model_setting = ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name == self.provider.provider, | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| ProviderModelSetting.model_name == model, | |||
| ) | |||
| .first() | |||
| ) | |||
| model_setting = self._get_provider_model_setting(model_type, model) | |||
| if model_setting: | |||
| model_setting.enabled = False | |||
| @@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel): | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| return self._get_provider_model_setting(model_type, model) | |||
| def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: | |||
| """ | |||
| Get load balancing config. | |||
| """ | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| return ( | |||
| db.session.query(ProviderModelSetting) | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name == self.provider.provider, | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| ProviderModelSetting.model_name == model, | |||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | |||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| ) | |||
| .first() | |||
| ) | |||
| @@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel): | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| load_balancing_config_count = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | |||
| LoadBalancingModelConfig.provider_name == self.provider.provider, | |||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| ) | |||
| @@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel): | |||
| if load_balancing_config_count <= 1: | |||
| raise ValueError("Model load balancing configuration must be more than 1.") | |||
| model_setting = ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name == self.provider.provider, | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| ProviderModelSetting.model_name == model, | |||
| ) | |||
| .first() | |||
| ) | |||
| model_setting = self._get_provider_model_setting(model_type, model) | |||
| if model_setting: | |||
| model_setting.load_balancing_enabled = True | |||
| @@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel): | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| model_setting = ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name == self.provider.provider, | |||
| ProviderModelSetting.provider_name.in_(provider_names), | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| ProviderModelSetting.model_name == model, | |||
| ) | |||
| @@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel): | |||
| return | |||
| # get preferred provider | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| provider_names.append(model_provider_id.provider_name) | |||
| preferred_model_provider = ( | |||
| db.session.query(TenantPreferredModelProvider) | |||
| .filter( | |||
| TenantPreferredModelProvider.tenant_id == self.tenant_id, | |||
| TenantPreferredModelProvider.provider_name == self.provider.provider, | |||
| TenantPreferredModelProvider.provider_name.in_(provider_names), | |||
| ) | |||
| .first() | |||
| ) | |||