|
|
|
@@ -754,7 +754,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
:param only_active: return active model only |
|
|
|
:return: |
|
|
|
""" |
|
|
|
provider_models = self.get_provider_models(model_type, only_active) |
|
|
|
provider_models = self.get_provider_models(model_type, only_active, model) |
|
|
|
|
|
|
|
for provider_model in provider_models: |
|
|
|
if provider_model.model == model: |
|
|
|
@@ -763,12 +763,13 @@ class ProviderConfiguration(BaseModel): |
|
|
|
return None |
|
|
|
|
|
|
|
def get_provider_models( |
|
|
|
self, model_type: Optional[ModelType] = None, only_active: bool = False |
|
|
|
self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None |
|
|
|
) -> list[ModelWithProviderEntity]: |
|
|
|
""" |
|
|
|
Get provider models. |
|
|
|
:param model_type: model type |
|
|
|
:param only_active: only active models |
|
|
|
:param model: model name |
|
|
|
:return: |
|
|
|
""" |
|
|
|
model_provider_factory = ModelProviderFactory(self.tenant_id) |
|
|
|
@@ -791,7 +792,10 @@ class ProviderConfiguration(BaseModel): |
|
|
|
) |
|
|
|
else: |
|
|
|
provider_models = self._get_custom_provider_models( |
|
|
|
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map |
|
|
|
model_types=model_types, |
|
|
|
provider_schema=provider_schema, |
|
|
|
model_setting_map=model_setting_map, |
|
|
|
model=model, |
|
|
|
) |
|
|
|
|
|
|
|
if only_active: |
|
|
|
@@ -943,6 +947,7 @@ class ProviderConfiguration(BaseModel): |
|
|
|
model_types: Sequence[ModelType], |
|
|
|
provider_schema: ProviderEntity, |
|
|
|
model_setting_map: dict[ModelType, dict[str, ModelSettings]], |
|
|
|
model: Optional[str] = None, |
|
|
|
) -> list[ModelWithProviderEntity]: |
|
|
|
""" |
|
|
|
Get custom provider models. |
|
|
|
@@ -995,7 +1000,8 @@ class ProviderConfiguration(BaseModel): |
|
|
|
for model_configuration in self.custom_configuration.models: |
|
|
|
if model_configuration.model_type not in model_types: |
|
|
|
continue |
|
|
|
|
|
|
|
if model and model != model_configuration.model: |
|
|
|
continue |
|
|
|
try: |
|
|
|
custom_model_schema = self.get_model_schema( |
|
|
|
model_type=model_configuration.model_type, |