Przeglądaj źródła

fix: old custom model not display credential name (#25112)

tags/2.0.0-beta.1
非法操作 1 miesiąc temu
rodzic
commit
0a0ae16bd6
No account linked to committer's email address
1 zmienionych plików z 33 dodań i 23 usunięć
  1. 33
    23
      api/core/provider_manager.py

+ 33
- 23
api/core/provider_manager.py Wyświetl plik

@@ -150,6 +150,9 @@ class ProviderManager:
tenant_id
)

# Get All provider model credentials
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)

provider_configurations = ProviderConfigurations(tenant_id=tenant_id)

# Construct ProviderConfiguration objects for each provider
@@ -171,10 +174,18 @@ class ProviderManager:
provider_model_records.extend(
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
)
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
provider_entity.provider, []
)
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
provider_model_credentials.extend(
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
)

# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
tenant_id, provider_entity, provider_records, provider_model_records
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
)

# Convert to system configuration
@@ -453,6 +464,24 @@ class ProviderManager:
)
return provider_name_to_provider_model_settings_dict

@staticmethod
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
"""
Get All provider model credentials of the workspace.

:param tenant_id: workspace id
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
provider_model_credential
)
return provider_name_to_provider_model_credentials_dict

@staticmethod
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
"""
@@ -539,23 +568,6 @@ class ProviderManager:
for credential in available_credentials
]

@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name

:param tenant_id: workspace id
:param provider_name: provider name

"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)

all_credentials = session.scalars(stmt).all()
return all_credentials

@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@@ -632,6 +644,7 @@ class ProviderManager:
provider_entity: ProviderEntity,
provider_records: list[Provider],
provider_model_records: list[ProviderModel],
provider_model_credentials: list[ProviderModelCredential],
) -> CustomConfiguration:
"""
Convert to custom configuration.
@@ -647,15 +660,12 @@ class ProviderManager:
tenant_id, provider_entity, provider_records
)

# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)

# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)

# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
)

can_added_models = [

Ładowanie…
Anuluj
Zapisz