Ver código fonte

fix: old custom model not display credential name (#25112)

tags/2.0.0-beta.1
非法操作 1 mês atrás
pai
commit
0a0ae16bd6
Nenhuma conta vinculada ao e-mail do autor do commit
1 arquivos alterados com 33 adições e 23 exclusões
  1. 33
    23
      api/core/provider_manager.py

+ 33
- 23
api/core/provider_manager.py Ver arquivo

tenant_id tenant_id
) )


# Get All provider model credentials
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)

provider_configurations = ProviderConfigurations(tenant_id=tenant_id) provider_configurations = ProviderConfigurations(tenant_id=tenant_id)


# Construct ProviderConfiguration objects for each provider # Construct ProviderConfiguration objects for each provider
provider_model_records.extend( provider_model_records.extend(
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
) )
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
provider_entity.provider, []
)
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
provider_model_credentials.extend(
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
)


# Convert to custom configuration # Convert to custom configuration
custom_configuration = self._to_custom_configuration( custom_configuration = self._to_custom_configuration(
tenant_id, provider_entity, provider_records, provider_model_records
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
) )


# Convert to system configuration # Convert to system configuration
) )
return provider_name_to_provider_model_settings_dict return provider_name_to_provider_model_settings_dict


@staticmethod
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
"""
Get All provider model credentials of the workspace.

:param tenant_id: workspace id
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
provider_model_credential
)
return provider_name_to_provider_model_credentials_dict

@staticmethod @staticmethod
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
""" """
for credential in available_credentials for credential in available_credentials
] ]


@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name

:param tenant_id: workspace id
:param provider_name: provider name

"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)

all_credentials = session.scalars(stmt).all()
return all_credentials

@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
provider_entity: ProviderEntity, provider_entity: ProviderEntity,
provider_records: list[Provider], provider_records: list[Provider],
provider_model_records: list[ProviderModel], provider_model_records: list[ProviderModel],
provider_model_credentials: list[ProviderModelCredential],
) -> CustomConfiguration: ) -> CustomConfiguration:
""" """
Convert to custom configuration. Convert to custom configuration.
tenant_id, provider_entity, provider_records tenant_id, provider_entity, provider_records
) )


# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)

# Get custom models which have not been added to the model list yet # Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)


# Get custom model configurations # Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations( custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
) )


can_added_models = [ can_added_models = [

Carregando…
Cancelar
Salvar