Browse Source

fix: model provider credentials null value validate failed (#2009)

tags/0.4.7
takatost 1 year ago
parent
commit
1779cea6e3
No account linked to committer's email address

+ 2
- 13
api/core/entities/provider_configuration.py View File

@@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel):
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

model_provider_factory.provider_credentials_validate(
credentials = model_provider_factory.provider_credentials_validate(
self.provider.provider,
credentials
)
@@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel):
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

model_provider_factory.model_credentials_validate(
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
)

model_schema = (
model_provider_factory.get_provider_instance(self.provider.provider)
.get_model_instance(model_type)._get_customizable_model_schema(
model=model,
credentials=credentials
)
)

if model_schema:
credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))

for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

+ 10
- 6
api/core/model_runtime/model_providers/model_provider_factory.py View File

@@ -61,7 +61,7 @@ class ModelProviderFactory:
# return providers
return providers

def provider_credentials_validate(self, provider: str, credentials: dict) -> None:
def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
"""
Validate provider credentials

@@ -80,13 +80,15 @@ class ModelProviderFactory:

# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
validator.validate_and_filter(credentials)
filtered_credentials = validator.validate_and_filter(credentials)

# validate the credentials, raise exception if validation failed
model_provider_instance.validate_provider_credentials(credentials)
model_provider_instance.validate_provider_credentials(filtered_credentials)

return filtered_credentials

def model_credentials_validate(self, provider: str, model_type: ModelType,
model: str, credentials: dict) -> None:
model: str, credentials: dict) -> dict:
"""
Validate model credentials

@@ -107,13 +109,15 @@ class ModelProviderFactory:

# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
validator.validate_and_filter(credentials)
filtered_credentials = validator.validate_and_filter(credentials)

# get model instance of the model type
model_instance = model_provider_instance.get_model_instance(model_type)

# call validate_credentials method of model type to validate credentials, raise exception if validation failed
model_instance.validate_credentials(model, credentials)
model_instance.validate_credentials(model, filtered_credentials)

return filtered_credentials

def get_models(self,
provider: Optional[str] = None,

+ 1
- 1
api/core/model_runtime/schema_validators/common_validator.py View File

@@ -46,7 +46,7 @@ class CommonValidator:
:return: validated credential form schema value
"""
# If the variable does not exist in credentials
if credential_form_schema.variable not in credentials:
if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
# If required is True, an exception is thrown
if credential_form_schema.required:
raise ValueError(f'Variable {credential_form_schema.variable} is required')

Loading…
Cancel
Save