| @@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType | |||
| from core.tools.errors import ToolInvokeError | |||
| from extensions.ext_database import db | |||
| from factories.file_factory import build_from_mapping | |||
| from models.account import Account | |||
| @@ -96,11 +97,8 @@ class WorkflowTool(Tool): | |||
| assert isinstance(result, dict) | |||
| data = result.get("data", {}) | |||
| if data.get("error"): | |||
| raise Exception(data.get("error")) | |||
| if data.get("error"): | |||
| raise Exception(data.get("error")) | |||
| if err := data.get("error"): | |||
| raise ToolInvokeError(err) | |||
| outputs = data.get("outputs") | |||
| if outputs is None: | |||
| @@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod | |||
| from core.plugin.manager.exc import PluginDaemonClientSideError | |||
| from core.plugin.manager.plugin import PluginInstallationManager | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.tools.tool_engine import ToolEngine | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from core.variables.segments import ArrayAnySegment | |||
| @@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| try: | |||
| # convert tool messages | |||
| yield from self._transform_message(message_stream, tool_info, parameters_for_log) | |||
| except PluginDaemonClientSideError as e: | |||
| except (PluginDaemonClientSideError, ToolInvokeError) as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to transform tool message: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| @@ -0,0 +1,49 @@ | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ToolEntity, ToolIdentity | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.tools.workflow_as_tool.tool import WorkflowTool | |||
| def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch): | |||
| """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when | |||
| `WorkflowAppGenerator.generate` returns a result with `error` key inside | |||
| the `data` element. | |||
| """ | |||
| entity = ToolEntity( | |||
| identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), | |||
| parameters=[], | |||
| description=None, | |||
| output_schema=None, | |||
| has_runtime_parameters=False, | |||
| ) | |||
| runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) | |||
| tool = WorkflowTool( | |||
| workflow_app_id="", | |||
| workflow_as_tool_id="", | |||
| version="1", | |||
| workflow_entities={}, | |||
| workflow_call_depth=1, | |||
| entity=entity, | |||
| runtime=runtime, | |||
| ) | |||
| # needs to patch those methods to avoid database access. | |||
| monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) | |||
| monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) | |||
| monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None) | |||
| # replace `WorkflowAppGenerator.generate` 's return value. | |||
| monkeypatch.setattr( | |||
| "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", | |||
| lambda *args, **kwargs: {"data": {"error": "oops"}}, | |||
| ) | |||
| with pytest.raises(ToolInvokeError) as exc_info: | |||
| # WorkflowTool always returns a generator, so we need to iterate to | |||
| # actually `run` the tool. | |||
| list(tool.invoke("test_user", {})) | |||
| assert exc_info.value.args == ("oops",) | |||
| @@ -0,0 +1 @@ | |||
| @@ -0,0 +1,110 @@ | |||
| from collections.abc import Generator | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end import EndStreamParam | |||
| from core.workflow.nodes.enums import ErrorStrategy | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.tool import ToolNode | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType | |||
| def _create_tool_node(): | |||
| data = ToolNodeData( | |||
| title="Test Tool", | |||
| tool_parameters={}, | |||
| provider_id="test_tool", | |||
| provider_type=ToolProviderType.WORKFLOW, | |||
| provider_name="test tool", | |||
| tool_name="test tool", | |||
| tool_label="test tool", | |||
| tool_configurations={}, | |||
| plugin_unique_identifier=None, | |||
| desc="Exception handling test tool", | |||
| error_strategy=ErrorStrategy.FAIL_BRANCH, | |||
| version="1", | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| user_inputs={}, | |||
| ) | |||
| node = ToolNode( | |||
| id="1", | |||
| config={ | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| }, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| return node | |||
| class MockToolRuntime: | |||
| def get_merged_runtime_parameters(self): | |||
| pass | |||
| def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]: | |||
| yield from [] | |||
| raise ToolInvokeError("oops") | |||
| def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): | |||
| """Ensure that ToolNode can handle ToolInvokeError when transforming | |||
| messages generated by ToolEngine.generic_invoke. | |||
| """ | |||
| tool_node = _create_tool_node() | |||
| # Need to patch ToolManager and ToolEngine so that we don't | |||
| # have to set up a database. | |||
| monkeypatch.setattr( | |||
| "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime() | |||
| ) | |||
| monkeypatch.setattr( | |||
| "core.tools.tool_engine.ToolEngine.generic_invoke", | |||
| lambda *args, **kwargs: mock_message_stream(), | |||
| ) | |||
| streams = list(tool_node._run()) | |||
| assert len(streams) == 1 | |||
| stream = streams[0] | |||
| assert isinstance(stream, RunCompletedEvent) | |||
| result = stream.run_result | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||
| assert "oops" in result.error | |||
| assert "Failed to transform tool message:" in result.error | |||
| assert result.error_type == "ToolInvokeError" | |||