| @@ -150,6 +150,9 @@ class ProviderManager: | |||
| tenant_id | |||
| ) | |||
| # Get All provider model credentials | |||
| provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id) | |||
| provider_configurations = ProviderConfigurations(tenant_id=tenant_id) | |||
| # Construct ProviderConfiguration objects for each provider | |||
| @@ -171,10 +174,18 @@ class ProviderManager: | |||
| provider_model_records.extend( | |||
| provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) | |||
| ) | |||
| provider_model_credentials = provider_name_to_provider_model_credentials_dict.get( | |||
| provider_entity.provider, [] | |||
| ) | |||
| provider_id_entity = ModelProviderID(provider_name) | |||
| if provider_id_entity.is_langgenius(): | |||
| provider_model_credentials.extend( | |||
| provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, []) | |||
| ) | |||
| # Convert to custom configuration | |||
| custom_configuration = self._to_custom_configuration( | |||
| tenant_id, provider_entity, provider_records, provider_model_records | |||
| tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials | |||
| ) | |||
| # Convert to system configuration | |||
| @@ -453,6 +464,24 @@ class ProviderManager: | |||
| ) | |||
| return provider_name_to_provider_model_settings_dict | |||
| @staticmethod | |||
| def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]: | |||
| """ | |||
| Get All provider model credentials of the workspace. | |||
| :param tenant_id: workspace id | |||
| :return: | |||
| """ | |||
| provider_name_to_provider_model_credentials_dict = defaultdict(list) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id) | |||
| provider_model_credentials = session.scalars(stmt) | |||
| for provider_model_credential in provider_model_credentials: | |||
| provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append( | |||
| provider_model_credential | |||
| ) | |||
| return provider_name_to_provider_model_credentials_dict | |||
| @staticmethod | |||
| def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: | |||
| """ | |||
| @@ -539,23 +568,6 @@ class ProviderManager: | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]: | |||
| """ | |||
| Get all the credentials records from ProviderModelCredential by provider_name | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(ProviderModelCredential).where( | |||
| ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name | |||
| ) | |||
| all_credentials = session.scalars(stmt).all() | |||
| return all_credentials | |||
| @staticmethod | |||
| def _init_trial_provider_records( | |||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | |||
| @@ -632,6 +644,7 @@ class ProviderManager: | |||
| provider_entity: ProviderEntity, | |||
| provider_records: list[Provider], | |||
| provider_model_records: list[ProviderModel], | |||
| provider_model_credentials: list[ProviderModelCredential], | |||
| ) -> CustomConfiguration: | |||
| """ | |||
| Convert to custom configuration. | |||
| @@ -647,15 +660,12 @@ class ProviderManager: | |||
| tenant_id, provider_entity, provider_records | |||
| ) | |||
| # Get all model credentials once | |||
| all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider) | |||
| # Get custom models which have not been added to the model list yet | |||
| unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials) | |||
| unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials) | |||
| # Get custom model configurations | |||
| custom_model_configurations = self._get_custom_model_configurations( | |||
| tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials | |||
| tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials | |||
| ) | |||
| can_added_models = [ | |||