| @@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel): | |||
| if value == '[__HIDDEN__]' and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory.provider_credentials_validate( | |||
| credentials = model_provider_factory.provider_credentials_validate( | |||
| self.provider.provider, | |||
| credentials | |||
| ) | |||
| @@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel): | |||
| if value == '[__HIDDEN__]' and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory.model_credentials_validate( | |||
| credentials = model_provider_factory.model_credentials_validate( | |||
| provider=self.provider.provider, | |||
| model_type=model_type, | |||
| model=model, | |||
| credentials=credentials | |||
| ) | |||
| model_schema = ( | |||
| model_provider_factory.get_provider_instance(self.provider.provider) | |||
| .get_model_instance(model_type)._get_customizable_model_schema( | |||
| model=model, | |||
| credentials=credentials | |||
| ) | |||
| ) | |||
| if model_schema: | |||
| credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema)) | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| credentials[key] = encrypter.encrypt_token(self.tenant_id, value) | |||
| @@ -61,7 +61,7 @@ class ModelProviderFactory: | |||
| # return providers | |||
| return providers | |||
| def provider_credentials_validate(self, provider: str, credentials: dict) -> None: | |||
| def provider_credentials_validate(self, provider: str, credentials: dict) -> dict: | |||
| """ | |||
| Validate provider credentials | |||
| @@ -80,13 +80,15 @@ class ModelProviderFactory: | |||
| # validate provider credential schema | |||
| validator = ProviderCredentialSchemaValidator(provider_credential_schema) | |||
| validator.validate_and_filter(credentials) | |||
| filtered_credentials = validator.validate_and_filter(credentials) | |||
| # validate the credentials, raise exception if validation failed | |||
| model_provider_instance.validate_provider_credentials(credentials) | |||
| model_provider_instance.validate_provider_credentials(filtered_credentials) | |||
| return filtered_credentials | |||
| def model_credentials_validate(self, provider: str, model_type: ModelType, | |||
| model: str, credentials: dict) -> None: | |||
| model: str, credentials: dict) -> dict: | |||
| """ | |||
| Validate model credentials | |||
| @@ -107,13 +109,15 @@ class ModelProviderFactory: | |||
| # validate model credential schema | |||
| validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) | |||
| validator.validate_and_filter(credentials) | |||
| filtered_credentials = validator.validate_and_filter(credentials) | |||
| # get model instance of the model type | |||
| model_instance = model_provider_instance.get_model_instance(model_type) | |||
| # call validate_credentials method of model type to validate credentials, raise exception if validation failed | |||
| model_instance.validate_credentials(model, credentials) | |||
| model_instance.validate_credentials(model, filtered_credentials) | |||
| return filtered_credentials | |||
| def get_models(self, | |||
| provider: Optional[str] = None, | |||
| @@ -46,7 +46,7 @@ class CommonValidator: | |||
| :return: validated credential form schema value | |||
| """ | |||
| # If the variable does not exist in credentials | |||
| if credential_form_schema.variable not in credentials: | |||
| if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: | |||
| # If required is True, an exception is thrown | |||
| if credential_form_schema.required: | |||
| raise ValueError(f'Variable {credential_form_schema.variable} is required') | |||