소스 검색

fix: Ensure compatibility with old provider name when updating model credentials (#26017)

tags/1.9.0
非法操作 1 개월 전
부모
커밋
ef80d3b707
No account linked to committer's email address
2개의 변경된 파일52개의 추가작업 그리고 42개의 파일을 삭제
  1. 32
    40
      api/core/entities/provider_configuration.py
  2. 20
    2
      api/core/provider_manager.py

+ 32
- 40
api/core/entities/provider_configuration.py 파일 보기

""" """
Get custom provider record. Get custom provider record.
""" """
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(Provider).where( stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
Provider.provider_name.in_(self._get_provider_names()),
) )


return session.execute(stmt).scalar_one_or_none() return session.execute(stmt).scalar_one_or_none()
""" """
stmt = select(ProviderCredential.id).where( stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name, ProviderCredential.credential_name == credential_name,
) )
if exclude_id: if exclude_id:
try: try:
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
) )
credential_record = s.execute(stmt).scalar_one_or_none() credential_record = s.execute(stmt).scalar_one_or_none()
session=session, session=session,
query_factory=lambda: select(ProviderCredential).where( query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
), ),
) )


session=session, session=session,
query_factory=lambda: select(ProviderModelCredential).where( query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
), ),
logger.warning("Error generating next credential name: %s", str(e)) logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1" return "API KEY 1"


def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names

def create_provider_credential(self, credentials: dict, credential_name: str | None): def create_provider_credential(self, credentials: dict, credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
) )


# Get the credential record to update # Get the credential record to update
# Find all load balancing configs that use this credential_id # Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where( stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source, LoadBalancingModelConfig.credential_source_type == credential_source,
) )
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
) )


# Get the credential record to update # Get the credential record to update
# Check if this credential is used in load balancing configs # Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider", LoadBalancingModelConfig.credential_source_type == "provider",
) )
# if this is the last credential, we need to delete the provider record # if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where( count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
available_credentials_count = session.execute(count_stmt).scalar() or 0 available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record) session.delete(credential_record)
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
credential_record = session.execute(stmt).scalar_one_or_none() credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record: if not credential_record:
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
""" """
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name, ProviderModelCredential.credential_name == credential_name,
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )


lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model", LoadBalancingModelConfig.credential_source_type == "custom_model",
) )
# if this is the last credential, we need to delete the custom model record # if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where( count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
""" """
Get provider model setting. Get provider model setting.
""" """

model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(ProviderModelSetting).where( stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model, ProviderModelSetting.model_name == model,
) )
return return


def _switch(s: Session): def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)

stmt = select(TenantPreferredModelProvider).where( stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
) )
preferred_model_provider = s.execute(stmt).scalars().first() preferred_model_provider = s.execute(stmt).scalars().first()



+ 20
- 2
api/core/provider_manager.py 파일 보기



return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict


@staticmethod
def _get_provider_names(provider_name: str) -> list[str]:
"""
provider_name: `openai` or `langgenius/openai/openai`
return: [`openai`, `langgenius/openai/openai`]
"""
provider_names = [provider_name]
model_provider_id = ModelProviderID(provider_name)
if model_provider_id.is_langgenius():
if "/" in provider_name:
provider_names.append(model_provider_id.provider_name)
else:
provider_names.append(str(model_provider_id))
return provider_names

@staticmethod @staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
stmt = ( stmt = (
select(ProviderCredential) select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
)
.order_by(ProviderCredential.created_at.desc()) .order_by(ProviderCredential.created_at.desc())
) )


select(ProviderModelCredential) select(ProviderModelCredential)
.where( .where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
ProviderModelCredential.model_name == model_name, ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type, ProviderModelCredential.model_type == model_type,
) )

Loading…
취소
저장