| from typing import Optional | from typing import Optional | ||||
| from pydantic import BaseModel, ConfigDict, Field | from pydantic import BaseModel, ConfigDict, Field | ||||
| from sqlalchemy import or_ | |||||
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | ||||
| else [], | else [], | ||||
| ) | ) | ||||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||||
| def _get_custom_provider_credentials(self) -> Provider | None: | |||||
| """ | """ | ||||
| Validate custom credentials. | |||||
| :param credentials: provider credentials | |||||
| :return: | |||||
| Get custom provider credentials. | |||||
| """ | """ | ||||
| # get provider | # get provider | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | model_provider_id = ModelProviderID(self.provider.provider) | ||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | if model_provider_id.is_langgenius(): | ||||
| provider_record = ( | |||||
| db.session.query(Provider) | |||||
| .filter( | |||||
| Provider.tenant_id == self.tenant_id, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||||
| or_( | |||||
| Provider.provider_name == model_provider_id.provider_name, | |||||
| Provider.provider_name == self.provider.provider, | |||||
| ), | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| else: | |||||
| provider_record = ( | |||||
| db.session.query(Provider) | |||||
| .filter( | |||||
| Provider.tenant_id == self.tenant_id, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||||
| Provider.provider_name == self.provider.provider, | |||||
| ) | |||||
| .first() | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| provider_record = ( | |||||
| db.session.query(Provider) | |||||
| .filter( | |||||
| Provider.tenant_id == self.tenant_id, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||||
| Provider.provider_name.in_(provider_names), | |||||
| ) | ) | ||||
| .first() | |||||
| ) | |||||
| return provider_record | |||||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||||
| """ | |||||
| Validate custom credentials. | |||||
| :param credentials: provider credentials | |||||
| :return: | |||||
| """ | |||||
| provider_record = self._get_custom_provider_credentials() | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self.extract_secret_variables( | provider_credential_secret_variables = self.extract_secret_variables( | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # get provider | # get provider | ||||
| provider_record = ( | |||||
| db.session.query(Provider) | |||||
| .filter( | |||||
| Provider.tenant_id == self.tenant_id, | |||||
| or_( | |||||
| Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, | |||||
| Provider.provider_name == self.provider.provider, | |||||
| ), | |||||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| provider_record = self._get_custom_provider_credentials() | |||||
| # delete provider | # delete provider | ||||
| if provider_record: | if provider_record: | ||||
| return None | return None | ||||
| def custom_model_credentials_validate( | |||||
| self, model_type: ModelType, model: str, credentials: dict | |||||
| ) -> tuple[ProviderModel | None, dict]: | |||||
| def _get_custom_model_credentials( | |||||
| self, | |||||
| model_type: ModelType, | |||||
| model: str, | |||||
| ) -> ProviderModel | None: | |||||
| """ | """ | ||||
| Validate custom model credentials. | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credentials: model credentials | |||||
| :return: | |||||
| Get custom model credentials. | |||||
| """ | """ | ||||
| # get provider model | # get provider model | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| provider_model_record = ( | provider_model_record = ( | ||||
| db.session.query(ProviderModel) | db.session.query(ProviderModel) | ||||
| .filter( | .filter( | ||||
| ProviderModel.tenant_id == self.tenant_id, | ProviderModel.tenant_id == self.tenant_id, | ||||
| ProviderModel.provider_name == self.provider.provider, | |||||
| ProviderModel.provider_name.in_(provider_names), | |||||
| ProviderModel.model_name == model, | ProviderModel.model_name == model, | ||||
| ProviderModel.model_type == model_type.to_origin_model_type(), | ProviderModel.model_type == model_type.to_origin_model_type(), | ||||
| ) | ) | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| return provider_model_record | |||||
| def custom_model_credentials_validate( | |||||
| self, model_type: ModelType, model: str, credentials: dict | |||||
| ) -> tuple[ProviderModel | None, dict]: | |||||
| """ | |||||
| Validate custom model credentials. | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credentials: model credentials | |||||
| :return: | |||||
| """ | |||||
| # get provider model | |||||
| provider_model_record = self._get_custom_model_credentials(model_type, model) | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self.extract_secret_variables( | provider_credential_secret_variables = self.extract_secret_variables( | ||||
| self.provider.model_credential_schema.credential_form_schemas | self.provider.model_credential_schema.credential_form_schemas | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # get provider model | # get provider model | ||||
| provider_model_record = ( | |||||
| db.session.query(ProviderModel) | |||||
| .filter( | |||||
| ProviderModel.tenant_id == self.tenant_id, | |||||
| ProviderModel.provider_name == self.provider.provider, | |||||
| ProviderModel.model_name == model, | |||||
| ProviderModel.model_type == model_type.to_origin_model_type(), | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| provider_model_record = self._get_custom_model_credentials(model_type, model) | |||||
| # delete provider model | # delete provider model | ||||
| if provider_model_record: | if provider_model_record: | ||||
| provider_model_credentials_cache.delete() | provider_model_credentials_cache.delete() | ||||
| def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | |||||
| def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: | |||||
| """ | """ | ||||
| Enable model. | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :return: | |||||
| Get provider model setting. | |||||
| """ | """ | ||||
| model_setting = ( | |||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| return ( | |||||
| db.session.query(ProviderModelSetting) | db.session.query(ProviderModelSetting) | ||||
| .filter( | .filter( | ||||
| ProviderModelSetting.tenant_id == self.tenant_id, | ProviderModelSetting.tenant_id == self.tenant_id, | ||||
| ProviderModelSetting.provider_name == self.provider.provider, | |||||
| ProviderModelSetting.provider_name.in_(provider_names), | |||||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | ProviderModelSetting.model_type == model_type.to_origin_model_type(), | ||||
| ProviderModelSetting.model_name == model, | ProviderModelSetting.model_name == model, | ||||
| ) | ) | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | |||||
| """ | |||||
| Enable model. | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :return: | |||||
| """ | |||||
| model_setting = self._get_provider_model_setting(model_type, model) | |||||
| if model_setting: | if model_setting: | ||||
| model_setting.enabled = True | model_setting.enabled = True | ||||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| model_setting = ( | |||||
| db.session.query(ProviderModelSetting) | |||||
| .filter( | |||||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||||
| ProviderModelSetting.provider_name == self.provider.provider, | |||||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||||
| ProviderModelSetting.model_name == model, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| model_setting = self._get_provider_model_setting(model_type, model) | |||||
| if model_setting: | if model_setting: | ||||
| model_setting.enabled = False | model_setting.enabled = False | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| return self._get_provider_model_setting(model_type, model) | |||||
| def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: | |||||
| """ | |||||
| Get load balancing config. | |||||
| """ | |||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| return ( | return ( | ||||
| db.session.query(ProviderModelSetting) | |||||
| db.session.query(LoadBalancingModelConfig) | |||||
| .filter( | .filter( | ||||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||||
| ProviderModelSetting.provider_name == self.provider.provider, | |||||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||||
| ProviderModelSetting.model_name == model, | |||||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | |||||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||||
| LoadBalancingModelConfig.model_name == model, | |||||
| ) | ) | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| load_balancing_config_count = ( | load_balancing_config_count = ( | ||||
| db.session.query(LoadBalancingModelConfig) | db.session.query(LoadBalancingModelConfig) | ||||
| .filter( | .filter( | ||||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | LoadBalancingModelConfig.tenant_id == self.tenant_id, | ||||
| LoadBalancingModelConfig.provider_name == self.provider.provider, | |||||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | ||||
| LoadBalancingModelConfig.model_name == model, | LoadBalancingModelConfig.model_name == model, | ||||
| ) | ) | ||||
| if load_balancing_config_count <= 1: | if load_balancing_config_count <= 1: | ||||
| raise ValueError("Model load balancing configuration must be more than 1.") | raise ValueError("Model load balancing configuration must be more than 1.") | ||||
| model_setting = ( | |||||
| db.session.query(ProviderModelSetting) | |||||
| .filter( | |||||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||||
| ProviderModelSetting.provider_name == self.provider.provider, | |||||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||||
| ProviderModelSetting.model_name == model, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| model_setting = self._get_provider_model_setting(model_type, model) | |||||
| if model_setting: | if model_setting: | ||||
| model_setting.load_balancing_enabled = True | model_setting.load_balancing_enabled = True | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| model_setting = ( | model_setting = ( | ||||
| db.session.query(ProviderModelSetting) | db.session.query(ProviderModelSetting) | ||||
| .filter( | .filter( | ||||
| ProviderModelSetting.tenant_id == self.tenant_id, | ProviderModelSetting.tenant_id == self.tenant_id, | ||||
| ProviderModelSetting.provider_name == self.provider.provider, | |||||
| ProviderModelSetting.provider_name.in_(provider_names), | |||||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | ProviderModelSetting.model_type == model_type.to_origin_model_type(), | ||||
| ProviderModelSetting.model_name == model, | ProviderModelSetting.model_name == model, | ||||
| ) | ) | ||||
| return | return | ||||
| # get preferred provider | # get preferred provider | ||||
| model_provider_id = ModelProviderID(self.provider.provider) | |||||
| provider_names = [self.provider.provider] | |||||
| if model_provider_id.is_langgenius(): | |||||
| provider_names.append(model_provider_id.provider_name) | |||||
| preferred_model_provider = ( | preferred_model_provider = ( | ||||
| db.session.query(TenantPreferredModelProvider) | db.session.query(TenantPreferredModelProvider) | ||||
| .filter( | .filter( | ||||
| TenantPreferredModelProvider.tenant_id == self.tenant_id, | TenantPreferredModelProvider.tenant_id == self.tenant_id, | ||||
| TenantPreferredModelProvider.provider_name == self.provider.provider, | |||||
| TenantPreferredModelProvider.provider_name.in_(provider_names), | |||||
| ) | ) | ||||
| .first() | .first() | ||||
| ) | ) |