|
|
|
@@ -498,12 +498,16 @@ class ToolManageService: |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def test_api_tool_preview( |
|
|
|
tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str |
|
|
|
tenant_id: str, |
|
|
|
provider_name: str, |
|
|
|
tool_name: str, |
|
|
|
credentials: dict, |
|
|
|
parameters: dict, |
|
|
|
schema_type: str, |
|
|
|
schema: str |
|
|
|
): |
|
|
|
""" |
|
|
|
test api tool before adding api tool provider |
|
|
|
|
|
|
|
1. parse schema into tool bundle |
|
|
|
""" |
|
|
|
if schema_type not in [member.value for member in ApiProviderSchemaType]: |
|
|
|
raise ValueError(f'invalid schema type {schema_type}') |
|
|
|
@@ -518,15 +522,21 @@ class ToolManageService: |
|
|
|
if tool_bundle is None: |
|
|
|
raise ValueError(f'invalid tool name {tool_name}') |
|
|
|
|
|
|
|
# create a fake db provider |
|
|
|
db_provider = ApiToolProvider( |
|
|
|
tenant_id='', user_id='', name='', icon='', |
|
|
|
schema=schema, |
|
|
|
description='', |
|
|
|
schema_type_str=ApiProviderSchemaType.OPENAPI.value, |
|
|
|
tools_str=serialize_base_model_array(tool_bundles), |
|
|
|
credentials_str=json.dumps(credentials), |
|
|
|
) |
|
|
|
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( |
|
|
|
ApiToolProvider.tenant_id == tenant_id, |
|
|
|
ApiToolProvider.name == provider_name, |
|
|
|
).first() |
|
|
|
|
|
|
|
if not db_provider: |
|
|
|
# create a fake db provider |
|
|
|
db_provider = ApiToolProvider( |
|
|
|
tenant_id='', user_id='', name='', icon='', |
|
|
|
schema=schema, |
|
|
|
description='', |
|
|
|
schema_type_str=ApiProviderSchemaType.OPENAPI.value, |
|
|
|
tools_str=serialize_base_model_array(tool_bundles), |
|
|
|
credentials_str=json.dumps(credentials), |
|
|
|
) |
|
|
|
|
|
|
|
if 'auth_type' not in credentials: |
|
|
|
raise ValueError('auth_type is required') |
|
|
|
@@ -539,6 +549,19 @@ class ToolManageService: |
|
|
|
# load tools into provider entity |
|
|
|
provider_controller.load_bundled_tools(tool_bundles) |
|
|
|
|
|
|
|
# decrypt credentials |
|
|
|
if db_provider.id: |
|
|
|
tool_configuration = ToolConfiguration( |
|
|
|
tenant_id=tenant_id, |
|
|
|
provider_controller=provider_controller |
|
|
|
) |
|
|
|
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) |
|
|
|
# check if the credential has changed, save the original credential |
|
|
|
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) |
|
|
|
for name, value in credentials.items(): |
|
|
|
if name in masked_credentials and value == masked_credentials[name]: |
|
|
|
credentials[name] = decrypted_credentials[name] |
|
|
|
|
|
|
|
try: |
|
|
|
provider_controller.validate_credentials_format(credentials) |
|
|
|
# get tool |