Explorar el Código

refactor: optimize provider configuration queries with provider name … (#15491)

tags/1.0.1
Yeuoly hace 7 meses
padre
commit
a6bc642721
No account linked to committer's email address
Se han modificado 1 ficheros con 107 adiciones y 93 borrados
  1. 107
    93
      api/core/entities/provider_configuration.py

+ 107
- 93
api/core/entities/provider_configuration.py Ver fichero

@@ -7,7 +7,6 @@ from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_

from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
@@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
else [],
)

def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
def _get_custom_provider_credentials(self) -> Provider | None:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
Get custom provider credentials.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
or_(
Provider.provider_name == model_provider_id.provider_name,
Provider.provider_name == self.provider.provider,
),
)
.first()
)
else:
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name == self.provider.provider,
)
.first()
provider_names.append(model_provider_id.provider_name)

provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
)
.first()
)

return provider_record

def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
provider_record = self._get_custom_provider_credentials()

# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
@@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
or_(
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
Provider.provider_name == self.provider.provider,
),
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
provider_record = self._get_custom_provider_credentials()

# delete provider
if provider_record:
@@ -349,29 +335,47 @@ class ProviderConfiguration(BaseModel):

return None

def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel | None, dict]:
def _get_custom_model_credentials(
self,
model_type: ModelType,
model: str,
) -> ProviderModel | None:
"""
Validate custom model credentials.

:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)

return provider_model_record

def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel | None, dict]:
"""
Validate custom model credentials.

:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# get provider model
provider_model_record = self._get_custom_model_credentials(model_type, model)

# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
@@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
provider_model_record = self._get_custom_model_credentials(model_type, model)

# delete provider model
if provider_model_record:
@@ -475,24 +470,35 @@ class ProviderConfiguration(BaseModel):

provider_model_credentials_cache.delete()

def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
"""
Enable model.
:param model_type: model type
:param model: model name
:return:
Get provider model setting.
"""
model_setting = (
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

return (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)

def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Enable model.
:param model_type: model type
:param model: model name
:return:
"""
model_setting = self._get_provider_model_setting(model_type, model)

if model_setting:
model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)

if model_setting:
model_setting.enabled = False
@@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
return self._get_provider_model_setting(model_type, model)

def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
"""
Get load balancing config.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

return (
db.session.query(ProviderModelSetting)
db.session.query(LoadBalancingModelConfig)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.first()
)
@@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
@@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
if load_balancing_config_count <= 1:
raise ValueError("Model load balancing configuration must be more than 1.")

model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
model_setting = self._get_provider_model_setting(model_type, model)

if model_setting:
model_setting.load_balancing_enabled = True
@@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
return

# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)
.first()
)

Cargando…
Cancelar
Guardar