|
|
|
@@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel): |
|
|
|
""" |
|
|
|
Get custom provider record. |
|
|
|
""" |
|
|
|
# get provider |
|
|
|
model_provider_id = ModelProviderID(self.provider.provider) |
|
|
|
provider_names = [self.provider.provider] |
|
|
|
if model_provider_id.is_langgenius(): |
|
|
|
provider_names.append(model_provider_id.provider_name) |
|
|
|
|
|
|
|
stmt = select(Provider).where( |
|
|
|
Provider.tenant_id == self.tenant_id, |
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value, |
|
|
|
Provider.provider_name.in_(provider_names), |
|
|
|
Provider.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
|
|
|
|
return session.execute(stmt).scalar_one_or_none() |
|
|
|
@@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
""" |
|
|
|
stmt = select(ProviderCredential.id).where( |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderCredential.credential_name == credential_name, |
|
|
|
) |
|
|
|
if exclude_id: |
|
|
|
@@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
try: |
|
|
|
stmt = select(ProviderCredential).where( |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderCredential.id == credential_id, |
|
|
|
) |
|
|
|
credential_record = s.execute(stmt).scalar_one_or_none() |
|
|
|
@@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
session=session, |
|
|
|
query_factory=lambda: select(ProviderCredential).where( |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
), |
|
|
|
) |
|
|
|
|
|
|
|
@@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
session=session, |
|
|
|
query_factory=lambda: select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
), |
|
|
|
@@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel): |
|
|
|
logger.warning("Error generating next credential name: %s", str(e)) |
|
|
|
return "API KEY 1" |
|
|
|
|
|
|
|
def _get_provider_names(self): |
|
|
|
""" |
|
|
|
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`. |
|
|
|
""" |
|
|
|
model_provider_id = ModelProviderID(self.provider.provider) |
|
|
|
provider_names = [self.provider.provider] |
|
|
|
if model_provider_id.is_langgenius(): |
|
|
|
provider_names.append(model_provider_id.provider_name) |
|
|
|
return provider_names |
|
|
|
|
|
|
|
def create_provider_credential(self, credentials: dict, credential_name: str | None): |
|
|
|
""" |
|
|
|
Add custom provider credentials. |
|
|
|
@@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderCredential).where( |
|
|
|
ProviderCredential.id == credential_id, |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
|
|
|
|
# Get the credential record to update |
|
|
|
@@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
# Find all load balancing configs that use this credential_id |
|
|
|
stmt = select(LoadBalancingModelConfig).where( |
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id, |
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider, |
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), |
|
|
|
LoadBalancingModelConfig.credential_id == credential_id, |
|
|
|
LoadBalancingModelConfig.credential_source_type == credential_source, |
|
|
|
) |
|
|
|
@@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderCredential).where( |
|
|
|
ProviderCredential.id == credential_id, |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
|
|
|
|
# Get the credential record to update |
|
|
|
@@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
# Check if this credential is used in load balancing configs |
|
|
|
lb_stmt = select(LoadBalancingModelConfig).where( |
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id, |
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider, |
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), |
|
|
|
LoadBalancingModelConfig.credential_id == credential_id, |
|
|
|
LoadBalancingModelConfig.credential_source_type == "provider", |
|
|
|
) |
|
|
|
@@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
# if this is the last credential, we need to delete the provider record |
|
|
|
count_stmt = select(func.count(ProviderCredential.id)).where( |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
available_credentials_count = session.execute(count_stmt).scalar() or 0 |
|
|
|
session.delete(credential_record) |
|
|
|
@@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderCredential).where( |
|
|
|
ProviderCredential.id == credential_id, |
|
|
|
ProviderCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderCredential.provider_name == self.provider.provider, |
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
credential_record = session.execute(stmt).scalar_one_or_none() |
|
|
|
if not credential_record: |
|
|
|
@@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
""" |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
ProviderModelCredential.credential_name == credential_name, |
|
|
|
@@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
|
|
|
|
lb_stmt = select(LoadBalancingModelConfig).where( |
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id, |
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider, |
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), |
|
|
|
LoadBalancingModelConfig.credential_id == credential_id, |
|
|
|
LoadBalancingModelConfig.credential_source_type == "custom_model", |
|
|
|
) |
|
|
|
@@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
# if this is the last credential, we need to delete the custom model record |
|
|
|
count_stmt = select(func.count(ProviderModelCredential.id)).where( |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
stmt = select(ProviderModelCredential).where( |
|
|
|
ProviderModelCredential.id == credential_id, |
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id, |
|
|
|
ProviderModelCredential.provider_name == self.provider.provider, |
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelCredential.model_name == model, |
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(), |
|
|
|
) |
|
|
|
@@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel): |
|
|
|
""" |
|
|
|
Get provider model setting. |
|
|
|
""" |
|
|
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider) |
|
|
|
provider_names = [self.provider.provider] |
|
|
|
if model_provider_id.is_langgenius(): |
|
|
|
provider_names.append(model_provider_id.provider_name) |
|
|
|
|
|
|
|
stmt = select(ProviderModelSetting).where( |
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id, |
|
|
|
ProviderModelSetting.provider_name.in_(provider_names), |
|
|
|
ProviderModelSetting.provider_name.in_(self._get_provider_names()), |
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(), |
|
|
|
ProviderModelSetting.model_name == model, |
|
|
|
) |
|
|
|
@@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel): |
|
|
|
return |
|
|
|
|
|
|
|
def _switch(s: Session): |
|
|
|
# get preferred provider |
|
|
|
model_provider_id = ModelProviderID(self.provider.provider) |
|
|
|
provider_names = [self.provider.provider] |
|
|
|
if model_provider_id.is_langgenius(): |
|
|
|
provider_names.append(model_provider_id.provider_name) |
|
|
|
|
|
|
|
stmt = select(TenantPreferredModelProvider).where( |
|
|
|
TenantPreferredModelProvider.tenant_id == self.tenant_id, |
|
|
|
TenantPreferredModelProvider.provider_name.in_(provider_names), |
|
|
|
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()), |
|
|
|
) |
|
|
|
preferred_model_provider = s.execute(stmt).scalars().first() |
|
|
|
|