Browse Source

refactor: optimize provider configuration queries with provider name … (#15491)

tags/1.0.1
Yeuoly 7 months ago
parent
commit
a6bc642721
No account linked to committer's email address
1 changed files with 107 additions and 93 deletions
  1. 107
    93
      api/core/entities/provider_configuration.py

+ 107
- 93
api/core/entities/provider_configuration.py View File

from typing import Optional from typing import Optional


from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_


from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
else [], else [],
) )


def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
def _get_custom_provider_credentials(self) -> Provider | None:
""" """
Validate custom credentials.
:param credentials: provider credentials
:return:
Get custom provider credentials.
""" """
# get provider # get provider
model_provider_id = ModelProviderID(self.provider.provider) model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius(): if model_provider_id.is_langgenius():
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
or_(
Provider.provider_name == model_provider_id.provider_name,
Provider.provider_name == self.provider.provider,
),
)
.first()
)
else:
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name == self.provider.provider,
)
.first()
provider_names.append(model_provider_id.provider_name)

provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
) )
.first()
)

return provider_record

def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
provider_record = self._get_custom_provider_credentials()


# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
:return: :return:
""" """
# get provider # get provider
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
or_(
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
Provider.provider_name == self.provider.provider,
),
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
provider_record = self._get_custom_provider_credentials()


# delete provider # delete provider
if provider_record: if provider_record:


return None return None


def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel | None, dict]:
def _get_custom_model_credentials(
self,
model_type: ModelType,
model: str,
) -> ProviderModel | None:
""" """
Validate custom model credentials.

:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
Get custom model credentials.
""" """
# get provider model # get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

provider_model_record = ( provider_model_record = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.filter( .filter(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model, ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(), ProviderModel.model_type == model_type.to_origin_model_type(),
) )
.first() .first()
) )


return provider_model_record

def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel | None, dict]:
"""
Validate custom model credentials.

:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# get provider model
provider_model_record = self._get_custom_model_credentials(model_type, model)

# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas self.provider.model_credential_schema.credential_form_schemas
:return: :return:
""" """
# get provider model # get provider model
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
provider_model_record = self._get_custom_model_credentials(model_type, model)


# delete provider model # delete provider model
if provider_model_record: if provider_model_record:


provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()


def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
""" """
Enable model.
:param model_type: model type
:param model: model name
:return:
Get provider model setting.
""" """
model_setting = (
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

return (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model, ProviderModelSetting.model_name == model,
) )
.first() .first()
) )


def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Enable model.
:param model_type: model type
:param model: model name
:return:
"""
model_setting = self._get_provider_model_setting(model_type, model)

if model_setting: if model_setting:
model_setting.enabled = True model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
:param model: model name :param model: model name
:return: :return:
""" """
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)


if model_setting: if model_setting:
model_setting.enabled = False model_setting.enabled = False
:param model: model name :param model: model name
:return: :return:
""" """
return self._get_provider_model_setting(model_type, model)

def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
"""
Get load balancing config.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

return ( return (
db.session.query(ProviderModelSetting)
db.session.query(LoadBalancingModelConfig)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
) )
.first() .first()
) )
:param model: model name :param model: model name
:return: :return:
""" """
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

load_balancing_config_count = ( load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.model_name == model,
) )
if load_balancing_config_count <= 1: if load_balancing_config_count <= 1:
raise ValueError("Model load balancing configuration must be more than 1.") raise ValueError("Model load balancing configuration must be more than 1.")


model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)


if model_setting: if model_setting:
model_setting.load_balancing_enabled = True model_setting.load_balancing_enabled = True
:param model: model name :param model: model name
:return: :return:
""" """
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

model_setting = ( model_setting = (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model, ProviderModelSetting.model_name == model,
) )
return return


# get preferred provider # get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

preferred_model_provider = ( preferred_model_provider = (
db.session.query(TenantPreferredModelProvider) db.session.query(TenantPreferredModelProvider)
.filter( .filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider,
TenantPreferredModelProvider.provider_name.in_(provider_names),
) )
.first() .first()
) )

Loading…
Cancel
Save