Browse Source

fix: perferred model provider not match with provider. (#18282)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.3.0
-LAN- 6 months ago
parent
commit
22a1bc337f
No account linked to committer's email address
1 changed files with 21 additions and 8 deletions
  1. 21
    8
      api/core/provider_manager.py

+ 21
- 8
api/core/provider_manager.py View File



# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)


# Get All provider model settings # Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)


@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
) -> dict[str, list]:
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list[Provider]]:
""" """
Initialize trial provider records if not exists. Initialize trial provider records if not exists.


if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider(
new_provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. # TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name, provider_name=ModelProviderID(provider_name).provider_name,
quota_used=0, quota_used=0,
is_valid=True, is_valid=True,
) )
db.session.add(provider_record)
db.session.add(new_provider_record)
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
provider_record = (
existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
) )
.first() .first()
) )
if provider_record and not provider_record.is_valid:
provider_record.is_valid = True
if not existed_provider_record:
continue

if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit() db.session.commit()


provider_name_to_provider_records_dict[provider_name].append(provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)


return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict



Loading…
Cancel
Save