| from core.tools.__base.tool import Tool | from core.tools.__base.tool import Tool | ||||
| from core.tools.__base.tool_runtime import ToolRuntime | from core.tools.__base.tool_runtime import ToolRuntime | ||||
| from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType | from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType | ||||
| from core.tools.errors import ToolInvokeError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories.file_factory import build_from_mapping | from factories.file_factory import build_from_mapping | ||||
| from models.account import Account | from models.account import Account | ||||
| assert isinstance(result, dict) | assert isinstance(result, dict) | ||||
| data = result.get("data", {}) | data = result.get("data", {}) | ||||
| if data.get("error"): | |||||
| raise Exception(data.get("error")) | |||||
| if data.get("error"): | |||||
| raise Exception(data.get("error")) | |||||
| if err := data.get("error"): | |||||
| raise ToolInvokeError(err) | |||||
| outputs = data.get("outputs") | outputs = data.get("outputs") | ||||
| if outputs is None: | if outputs is None: |
| from core.plugin.manager.exc import PluginDaemonClientSideError | from core.plugin.manager.exc import PluginDaemonClientSideError | ||||
| from core.plugin.manager.plugin import PluginInstallationManager | from core.plugin.manager.plugin import PluginInstallationManager | ||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | ||||
| from core.tools.errors import ToolInvokeError | |||||
| from core.tools.tool_engine import ToolEngine | from core.tools.tool_engine import ToolEngine | ||||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | from core.tools.utils.message_transformer import ToolFileMessageTransformer | ||||
| from core.variables.segments import ArrayAnySegment | from core.variables.segments import ArrayAnySegment | ||||
| try: | try: | ||||
| # convert tool messages | # convert tool messages | ||||
| yield from self._transform_message(message_stream, tool_info, parameters_for_log) | yield from self._transform_message(message_stream, tool_info, parameters_for_log) | ||||
| except PluginDaemonClientSideError as e: | |||||
| except (PluginDaemonClientSideError, ToolInvokeError) as e: | |||||
| yield RunCompletedEvent( | yield RunCompletedEvent( | ||||
| run_result=NodeRunResult( | run_result=NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.FAILED, | status=WorkflowNodeExecutionStatus.FAILED, | ||||
| inputs=parameters_for_log, | inputs=parameters_for_log, | ||||
| metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, | ||||
| error=f"Failed to transform tool message: {str(e)}", | error=f"Failed to transform tool message: {str(e)}", | ||||
| error_type=type(e).__name__, | |||||
| ) | ) | ||||
| ) | ) | ||||
| import pytest | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from core.tools.__base.tool_runtime import ToolRuntime | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| from core.tools.entities.tool_entities import ToolEntity, ToolIdentity | |||||
| from core.tools.errors import ToolInvokeError | |||||
| from core.tools.workflow_as_tool.tool import WorkflowTool | |||||
| def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch): | |||||
| """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when | |||||
| `WorkflowAppGenerator.generate` returns a result with `error` key inside | |||||
| the `data` element. | |||||
| """ | |||||
| entity = ToolEntity( | |||||
| identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), | |||||
| parameters=[], | |||||
| description=None, | |||||
| output_schema=None, | |||||
| has_runtime_parameters=False, | |||||
| ) | |||||
| runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) | |||||
| tool = WorkflowTool( | |||||
| workflow_app_id="", | |||||
| workflow_as_tool_id="", | |||||
| version="1", | |||||
| workflow_entities={}, | |||||
| workflow_call_depth=1, | |||||
| entity=entity, | |||||
| runtime=runtime, | |||||
| ) | |||||
| # needs to patch those methods to avoid database access. | |||||
| monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) | |||||
| monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) | |||||
| monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None) | |||||
| # replace `WorkflowAppGenerator.generate` 's return value. | |||||
| monkeypatch.setattr( | |||||
| "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", | |||||
| lambda *args, **kwargs: {"data": {"error": "oops"}}, | |||||
| ) | |||||
| with pytest.raises(ToolInvokeError) as exc_info: | |||||
| # WorkflowTool always returns a generator, so we need to iterate to | |||||
| # actually `run` the tool. | |||||
| list(tool.invoke("test_user", {})) | |||||
| assert exc_info.value.args == ("oops",) |
| from collections.abc import Generator | |||||
| import pytest | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||||
| from core.tools.errors import ToolInvokeError | |||||
| from core.workflow.entities.node_entities import NodeRunResult | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | |||||
| from core.workflow.nodes.end import EndStreamParam | |||||
| from core.workflow.nodes.enums import ErrorStrategy | |||||
| from core.workflow.nodes.event import RunCompletedEvent | |||||
| from core.workflow.nodes.tool import ToolNode | |||||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||||
| from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType | |||||
| def _create_tool_node(): | |||||
| data = ToolNodeData( | |||||
| title="Test Tool", | |||||
| tool_parameters={}, | |||||
| provider_id="test_tool", | |||||
| provider_type=ToolProviderType.WORKFLOW, | |||||
| provider_name="test tool", | |||||
| tool_name="test tool", | |||||
| tool_label="test tool", | |||||
| tool_configurations={}, | |||||
| plugin_unique_identifier=None, | |||||
| desc="Exception handling test tool", | |||||
| error_strategy=ErrorStrategy.FAIL_BRANCH, | |||||
| version="1", | |||||
| ) | |||||
| variable_pool = VariablePool( | |||||
| system_variables={}, | |||||
| user_inputs={}, | |||||
| ) | |||||
| node = ToolNode( | |||||
| id="1", | |||||
| config={ | |||||
| "id": "1", | |||||
| "data": data.model_dump(), | |||||
| }, | |||||
| graph_init_params=GraphInitParams( | |||||
| tenant_id="1", | |||||
| app_id="1", | |||||
| workflow_type=WorkflowType.WORKFLOW, | |||||
| workflow_id="1", | |||||
| graph_config={}, | |||||
| user_id="1", | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.SERVICE_API, | |||||
| call_depth=0, | |||||
| ), | |||||
| graph=Graph( | |||||
| root_node_id="1", | |||||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||||
| answer_dependencies={}, | |||||
| answer_generate_route={}, | |||||
| ), | |||||
| end_stream_param=EndStreamParam( | |||||
| end_dependencies={}, | |||||
| end_stream_variable_selector_mapping={}, | |||||
| ), | |||||
| ), | |||||
| graph_runtime_state=GraphRuntimeState( | |||||
| variable_pool=variable_pool, | |||||
| start_at=0, | |||||
| ), | |||||
| ) | |||||
| return node | |||||
| class MockToolRuntime: | |||||
| def get_merged_runtime_parameters(self): | |||||
| pass | |||||
| def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]: | |||||
| yield from [] | |||||
| raise ToolInvokeError("oops") | |||||
| def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): | |||||
| """Ensure that ToolNode can handle ToolInvokeError when transforming | |||||
| messages generated by ToolEngine.generic_invoke. | |||||
| """ | |||||
| tool_node = _create_tool_node() | |||||
| # Need to patch ToolManager and ToolEngine so that we don't | |||||
| # have to set up a database. | |||||
| monkeypatch.setattr( | |||||
| "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime() | |||||
| ) | |||||
| monkeypatch.setattr( | |||||
| "core.tools.tool_engine.ToolEngine.generic_invoke", | |||||
| lambda *args, **kwargs: mock_message_stream(), | |||||
| ) | |||||
| streams = list(tool_node._run()) | |||||
| assert len(streams) == 1 | |||||
| stream = streams[0] | |||||
| assert isinstance(stream, RunCompletedEvent) | |||||
| result = stream.run_result | |||||
| assert isinstance(result, NodeRunResult) | |||||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||||
| assert "oops" in result.error | |||||
| assert "Failed to transform tool message:" in result.error | |||||
| assert result.error_type == "ToolInvokeError" |