|
|
|
@@ -29,6 +29,96 @@ class DatasourceProviderService: |
|
|
|
def __init__(self) -> None: |
|
|
|
self.provider_manager = PluginDatasourceManager() |
|
|
|
|
|
|
|
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: |
|
|
|
""" |
|
|
|
get default credentials |
|
|
|
""" |
|
|
|
with Session(db.engine) as session: |
|
|
|
datasource_provider = ( |
|
|
|
session.query(DatasourceProvider) |
|
|
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) |
|
|
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) |
|
|
|
.first() |
|
|
|
) |
|
|
|
if not datasource_provider: |
|
|
|
return {} |
|
|
|
return datasource_provider.encrypted_credentials |
|
|
|
|
|
|
|
def update_datasource_provider_name( |
|
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str |
|
|
|
): |
|
|
|
""" |
|
|
|
update datasource provider name |
|
|
|
""" |
|
|
|
with Session(db.engine) as session: |
|
|
|
target_provider = ( |
|
|
|
session.query(DatasourceProvider) |
|
|
|
.filter_by( |
|
|
|
tenant_id=tenant_id, |
|
|
|
id=credential_id, |
|
|
|
provider=datasource_provider_id.provider_name, |
|
|
|
plugin_id=datasource_provider_id.plugin_id, |
|
|
|
) |
|
|
|
.first() |
|
|
|
) |
|
|
|
if target_provider is None: |
|
|
|
raise ValueError("provider not found") |
|
|
|
|
|
|
|
if target_provider.name == name: |
|
|
|
return |
|
|
|
|
|
|
|
# check name is exist |
|
|
|
if ( |
|
|
|
session.query(DatasourceProvider) |
|
|
|
.filter_by( |
|
|
|
tenant_id=tenant_id, |
|
|
|
name=name, |
|
|
|
provider=datasource_provider_id.provider_name, |
|
|
|
plugin_id=datasource_provider_id.plugin_id, |
|
|
|
) |
|
|
|
.count() |
|
|
|
> 0 |
|
|
|
): |
|
|
|
raise ValueError("name is already exists") |
|
|
|
|
|
|
|
target_provider.name = name |
|
|
|
session.commit() |
|
|
|
return |
|
|
|
|
|
|
|
def set_default_datasource_provider( |
|
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str |
|
|
|
): |
|
|
|
""" |
|
|
|
set default datasource provider |
|
|
|
""" |
|
|
|
with Session(db.engine) as session: |
|
|
|
# get provider |
|
|
|
target_provider = ( |
|
|
|
session.query(DatasourceProvider) |
|
|
|
.filter_by( |
|
|
|
tenant_id=tenant_id, |
|
|
|
id=credential_id, |
|
|
|
provider=datasource_provider_id.provider_name, |
|
|
|
plugin_id=datasource_provider_id.plugin_id, |
|
|
|
) |
|
|
|
.first() |
|
|
|
) |
|
|
|
if target_provider is None: |
|
|
|
raise ValueError("provider not found") |
|
|
|
|
|
|
|
# clear default provider |
|
|
|
session.query(DatasourceProvider).filter_by( |
|
|
|
tenant_id=tenant_id, |
|
|
|
provider=target_provider.provider, |
|
|
|
plugin_id=target_provider.plugin_id, |
|
|
|
is_default=True, |
|
|
|
).update({"is_default": False}) |
|
|
|
|
|
|
|
# set new default provider |
|
|
|
target_provider.is_default = True |
|
|
|
session.commit() |
|
|
|
return {"result": "success"} |
|
|
|
|
|
|
|
def setup_oauth_custom_client_params( |
|
|
|
self, |
|
|
|
tenant_id: str, |
|
|
|
@@ -41,10 +131,6 @@ class DatasourceProviderService: |
|
|
|
""" |
|
|
|
if client_params is None and enabled is None: |
|
|
|
return |
|
|
|
provider_controller = PluginDatasourceManager() |
|
|
|
datasource_provider = provider_controller.fetch_datasource_provider( |
|
|
|
tenant_id=tenant_id, provider_id=str(datasource_provider_id) |
|
|
|
) |
|
|
|
with Session(db.engine) as session: |
|
|
|
tenant_oauth_client_params = ( |
|
|
|
session.query(DatasourceOauthTenantParamConfig) |
|
|
|
@@ -252,7 +338,7 @@ class DatasourceProviderService: |
|
|
|
) |
|
|
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}" |
|
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value |
|
|
|
) |
|
|
|
for key, value in credentials.items(): |
|
|
|
if key in provider_credential_secret_variables: |
|
|
|
@@ -310,7 +396,7 @@ class DatasourceProviderService: |
|
|
|
) |
|
|
|
if credential_valid: |
|
|
|
provider_credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}" |
|
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY.value |
|
|
|
) |
|
|
|
for key, value in credentials.items(): |
|
|
|
if key in provider_credential_secret_variables: |
|
|
|
@@ -329,7 +415,7 @@ class DatasourceProviderService: |
|
|
|
else: |
|
|
|
raise CredentialsValidateFailedError() |
|
|
|
|
|
|
|
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: |
|
|
|
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]: |
|
|
|
""" |
|
|
|
Extract secret input form variables. |
|
|
|
|
|
|
|
@@ -339,7 +425,16 @@ class DatasourceProviderService: |
|
|
|
datasource_provider = self.provider_manager.fetch_datasource_provider( |
|
|
|
tenant_id=tenant_id, provider_id=provider_id |
|
|
|
) |
|
|
|
credential_form_schemas = datasource_provider.declaration.credentials_schema |
|
|
|
credential_form_schemas = [] |
|
|
|
if credential_type == "api_key": |
|
|
|
credential_form_schemas = datasource_provider.declaration.credentials_schema |
|
|
|
elif credential_type == "oauth2": |
|
|
|
if not datasource_provider.declaration.oauth_schema: |
|
|
|
raise ValueError("Datasource provider oauth schema not found") |
|
|
|
credential_form_schemas = datasource_provider.declaration.oauth_schema.credentials_schema |
|
|
|
else: |
|
|
|
raise ValueError(f"Invalid credential type: {credential_type}") |
|
|
|
|
|
|
|
secret_input_form_variables = [] |
|
|
|
for credential_form_schema in credential_form_schemas: |
|
|
|
if credential_form_schema.type.value == FormType.SECRET_INPUT.value: |
|
|
|
@@ -368,11 +463,20 @@ class DatasourceProviderService: |
|
|
|
if not datasource_providers: |
|
|
|
return [] |
|
|
|
copy_credentials_list = [] |
|
|
|
default_provider = ( |
|
|
|
db.session.query(DatasourceProvider.id) |
|
|
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) |
|
|
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) |
|
|
|
.first() |
|
|
|
) |
|
|
|
default_provider_id = default_provider.id if default_provider else None |
|
|
|
for datasource_provider in datasource_providers: |
|
|
|
encrypted_credentials = datasource_provider.encrypted_credentials |
|
|
|
# Get provider credential secret variables |
|
|
|
credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" |
|
|
|
tenant_id=tenant_id, |
|
|
|
provider_id=f"{plugin_id}/{provider}", |
|
|
|
credential_type=datasource_provider.auth_type, |
|
|
|
) |
|
|
|
|
|
|
|
# Obfuscate provider credentials |
|
|
|
@@ -387,6 +491,7 @@ class DatasourceProviderService: |
|
|
|
"name": datasource_provider.name, |
|
|
|
"avatar_url": datasource_provider.avatar_url, |
|
|
|
"id": datasource_provider.id, |
|
|
|
"is_default": default_provider_id and datasource_provider.id == default_provider_id, |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
@@ -469,7 +574,9 @@ class DatasourceProviderService: |
|
|
|
encrypted_credentials = datasource_provider.encrypted_credentials |
|
|
|
# Get provider credential secret variables |
|
|
|
credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" |
|
|
|
tenant_id=tenant_id, |
|
|
|
provider_id=f"{plugin_id}/{provider}", |
|
|
|
credential_type=datasource_provider.auth_type, |
|
|
|
) |
|
|
|
|
|
|
|
# Obfuscate provider credentials |
|
|
|
@@ -507,12 +614,14 @@ class DatasourceProviderService: |
|
|
|
.first() |
|
|
|
) |
|
|
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" |
|
|
|
) |
|
|
|
if not datasource_provider: |
|
|
|
raise ValueError("Datasource provider not found") |
|
|
|
else: |
|
|
|
provider_credential_secret_variables = self.extract_secret_variables( |
|
|
|
tenant_id=tenant_id, |
|
|
|
provider_id=f"{plugin_id}/{provider}", |
|
|
|
credential_type=datasource_provider.auth_type, |
|
|
|
) |
|
|
|
original_credentials = datasource_provider.encrypted_credentials |
|
|
|
for key, value in credentials.items(): |
|
|
|
if key in provider_credential_secret_variables: |