| """ | |||||
| QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. | |||||
| This engine uses a modular architecture with separated packages following | |||||
| Domain-Driven Design principles for improved maintainability and testability. | |||||
| """ | |||||
| import contextvars | |||||
| import logging | |||||
| import queue | |||||
| from collections.abc import Generator, Mapping | |||||
| from typing import final | |||||
| from flask import Flask, current_app | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from core.workflow.entities import GraphRuntimeState | |||||
| from core.workflow.enums import NodeExecutionType | |||||
| from core.workflow.graph import Graph | |||||
| from core.workflow.graph_events import ( | |||||
| GraphEngineEvent, | |||||
| GraphNodeEventBase, | |||||
| GraphRunAbortedEvent, | |||||
| GraphRunFailedEvent, | |||||
| GraphRunStartedEvent, | |||||
| GraphRunSucceededEvent, | |||||
| ) | |||||
| from models.enums import UserFrom | |||||
| from .command_processing import AbortCommandHandler, CommandProcessor | |||||
| from .domain import ExecutionContext, GraphExecution | |||||
| from .entities.commands import AbortCommand | |||||
| from .error_handling import ErrorHandler | |||||
| from .event_management import EventHandler, EventManager | |||||
| from .graph_traversal import EdgeProcessor, SkipPropagator | |||||
| from .layers.base import Layer | |||||
| from .orchestration import Dispatcher, ExecutionCoordinator | |||||
| from .protocols.command_channel import CommandChannel | |||||
| from .response_coordinator import ResponseStreamCoordinator | |||||
| from .state_management import UnifiedStateManager | |||||
| from .worker_management import SimpleWorkerPool | |||||
| logger = logging.getLogger(__name__) | |||||
| @final | |||||
| class GraphEngine: | |||||
| """ | |||||
| Queue-based graph execution engine. | |||||
| Uses a modular architecture that delegates responsibilities to specialized | |||||
| subsystems, following Domain-Driven Design and SOLID principles. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| tenant_id: str, | |||||
| app_id: str, | |||||
| workflow_id: str, | |||||
| user_id: str, | |||||
| user_from: UserFrom, | |||||
| invoke_from: InvokeFrom, | |||||
| call_depth: int, | |||||
| graph: Graph, | |||||
| graph_config: Mapping[str, object], | |||||
| graph_runtime_state: GraphRuntimeState, | |||||
| max_execution_steps: int, | |||||
| max_execution_time: int, | |||||
| command_channel: CommandChannel, | |||||
| min_workers: int | None = None, | |||||
| max_workers: int | None = None, | |||||
| scale_up_threshold: int | None = None, | |||||
| scale_down_idle_time: float | None = None, | |||||
| ) -> None: | |||||
| """Initialize the graph engine with all subsystems and dependencies.""" | |||||
| # === Domain Models === | |||||
| # Execution context encapsulates workflow execution metadata | |||||
| self._execution_context = ExecutionContext( | |||||
| tenant_id=tenant_id, | |||||
| app_id=app_id, | |||||
| workflow_id=workflow_id, | |||||
| user_id=user_id, | |||||
| user_from=user_from, | |||||
| invoke_from=invoke_from, | |||||
| call_depth=call_depth, | |||||
| max_execution_steps=max_execution_steps, | |||||
| max_execution_time=max_execution_time, | |||||
| ) | |||||
| # Graph execution tracks the overall execution state | |||||
| self._graph_execution = GraphExecution(workflow_id=workflow_id) | |||||
| # === Core Dependencies === | |||||
| # Graph structure and configuration | |||||
| self._graph = graph | |||||
| self._graph_config = graph_config | |||||
| self._graph_runtime_state = graph_runtime_state | |||||
| self._command_channel = command_channel | |||||
| # === Worker Management Parameters === | |||||
| # Parameters for dynamic worker pool scaling | |||||
| self._min_workers = min_workers | |||||
| self._max_workers = max_workers | |||||
| self._scale_up_threshold = scale_up_threshold | |||||
| self._scale_down_idle_time = scale_down_idle_time | |||||
| # === Execution Queues === | |||||
| # Queue for nodes ready to execute | |||||
| self._ready_queue: queue.Queue[str] = queue.Queue() | |||||
| # Queue for events generated during execution | |||||
| self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() | |||||
| # === State Management === | |||||
| # Unified state manager handles all node state transitions and queue operations | |||||
| self._state_manager = UnifiedStateManager(self._graph, self._ready_queue) | |||||
| # === Response Coordination === | |||||
| # Coordinates response streaming from response nodes | |||||
| self._response_coordinator = ResponseStreamCoordinator( | |||||
| variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph | |||||
| ) | |||||
| # === Event Management === | |||||
| # Event manager handles both collection and emission of events | |||||
| self._event_manager = EventManager() | |||||
| # === Error Handling === | |||||
| # Centralized error handler for graph execution errors | |||||
| self._error_handler = ErrorHandler(self._graph, self._graph_execution) | |||||
| # === Graph Traversal Components === | |||||
| # Propagates skip status through the graph when conditions aren't met | |||||
| self._skip_propagator = SkipPropagator( | |||||
| graph=self._graph, | |||||
| state_manager=self._state_manager, | |||||
| ) | |||||
| # Processes edges to determine next nodes after execution | |||||
| # Also handles conditional branching and route selection | |||||
| self._edge_processor = EdgeProcessor( | |||||
| graph=self._graph, | |||||
| state_manager=self._state_manager, | |||||
| response_coordinator=self._response_coordinator, | |||||
| skip_propagator=self._skip_propagator, | |||||
| ) | |||||
| # === Event Handler Registry === | |||||
| # Central registry for handling all node execution events | |||||
| self._event_handler_registry = EventHandler( | |||||
| graph=self._graph, | |||||
| graph_runtime_state=self._graph_runtime_state, | |||||
| graph_execution=self._graph_execution, | |||||
| response_coordinator=self._response_coordinator, | |||||
| event_collector=self._event_manager, | |||||
| edge_processor=self._edge_processor, | |||||
| state_manager=self._state_manager, | |||||
| error_handler=self._error_handler, | |||||
| ) | |||||
| # === Command Processing === | |||||
| # Processes external commands (e.g., abort requests) | |||||
| self._command_processor = CommandProcessor( | |||||
| command_channel=self._command_channel, | |||||
| graph_execution=self._graph_execution, | |||||
| ) | |||||
| # Register abort command handler | |||||
| abort_handler = AbortCommandHandler() | |||||
| self._command_processor.register_handler( | |||||
| AbortCommand, | |||||
| abort_handler, | |||||
| ) | |||||
| # === Worker Pool Setup === | |||||
| # Capture Flask app context for worker threads | |||||
| flask_app: Flask | None = None | |||||
| try: | |||||
| app = current_app._get_current_object() # type: ignore | |||||
| if isinstance(app, Flask): | |||||
| flask_app = app | |||||
| except RuntimeError: | |||||
| pass | |||||
| # Capture context variables for worker threads | |||||
| context_vars = contextvars.copy_context() | |||||
| # Create worker pool for parallel node execution | |||||
| self._worker_pool = SimpleWorkerPool( | |||||
| ready_queue=self._ready_queue, | |||||
| event_queue=self._event_queue, | |||||
| graph=self._graph, | |||||
| flask_app=flask_app, | |||||
| context_vars=context_vars, | |||||
| min_workers=self._min_workers, | |||||
| max_workers=self._max_workers, | |||||
| scale_up_threshold=self._scale_up_threshold, | |||||
| scale_down_idle_time=self._scale_down_idle_time, | |||||
| ) | |||||
| # === Orchestration === | |||||
| # Coordinates the overall execution lifecycle | |||||
| self._execution_coordinator = ExecutionCoordinator( | |||||
| graph_execution=self._graph_execution, | |||||
| state_manager=self._state_manager, | |||||
| event_handler=self._event_handler_registry, | |||||
| event_collector=self._event_manager, | |||||
| command_processor=self._command_processor, | |||||
| worker_pool=self._worker_pool, | |||||
| ) | |||||
| # Dispatches events and manages execution flow | |||||
| self._dispatcher = Dispatcher( | |||||
| event_queue=self._event_queue, | |||||
| event_handler=self._event_handler_registry, | |||||
| event_collector=self._event_manager, | |||||
| execution_coordinator=self._execution_coordinator, | |||||
| max_execution_time=self._execution_context.max_execution_time, | |||||
| event_emitter=self._event_manager, | |||||
| ) | |||||
| # === Extensibility === | |||||
| # Layers allow plugins to extend engine functionality | |||||
| self._layers: list[Layer] = [] | |||||
| # === Validation === | |||||
| # Ensure all nodes share the same GraphRuntimeState instance | |||||
| self._validate_graph_state_consistency() | |||||
| def _validate_graph_state_consistency(self) -> None: | |||||
| """Validate that all nodes share the same GraphRuntimeState.""" | |||||
| expected_state_id = id(self._graph_runtime_state) | |||||
| for node in self._graph.nodes.values(): | |||||
| if id(node.graph_runtime_state) != expected_state_id: | |||||
| raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") | |||||
| def layer(self, layer: Layer) -> "GraphEngine": | |||||
| """Add a layer for extending functionality.""" | |||||
| self._layers.append(layer) | |||||
| return self | |||||
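# Illustrative sketch (not part of the engine): because layer() returns the engine,
# layers compose fluently, e.g.
#   engine.layer(DebugLoggingLayer()).layer(ExecutionLimitsLayer(max_steps, max_time))
# The constructor arguments shown here are assumptions for illustration; see the
# WorkflowEntry wiring later in this document for the real configuration.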
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||||
| """ | |||||
| Execute the graph using the modular architecture. | |||||
| Returns: | |||||
| Generator yielding GraphEngineEvent instances | |||||
| """ | |||||
| try: | |||||
| # Initialize layers | |||||
| self._initialize_layers() | |||||
| # Start execution | |||||
| self._graph_execution.start() | |||||
| start_event = GraphRunStartedEvent() | |||||
| yield start_event | |||||
| # Start subsystems | |||||
| self._start_execution() | |||||
| # Yield events as they occur | |||||
| yield from self._event_manager.emit_events() | |||||
| # Handle completion | |||||
| if self._graph_execution.aborted: | |||||
| abort_reason = "Workflow execution aborted by user command" | |||||
| if self._graph_execution.error: | |||||
| abort_reason = str(self._graph_execution.error) | |||||
| yield GraphRunAbortedEvent( | |||||
| reason=abort_reason, | |||||
| outputs=self._graph_runtime_state.outputs, | |||||
| ) | |||||
| elif self._graph_execution.has_error: | |||||
| if self._graph_execution.error: | |||||
| raise self._graph_execution.error | |||||
| else: | |||||
| yield GraphRunSucceededEvent( | |||||
| outputs=self._graph_runtime_state.outputs, | |||||
| ) | |||||
| except Exception as e: | |||||
| yield GraphRunFailedEvent(error=str(e)) | |||||
| raise | |||||
| finally: | |||||
| self._stop_execution() | |||||
| def _initialize_layers(self) -> None: | |||||
| """Initialize layers with context.""" | |||||
| self._event_manager.set_layers(self._layers) | |||||
| for layer in self._layers: | |||||
| try: | |||||
| layer.initialize(self._graph_runtime_state, self._command_channel) | |||||
| except Exception as e: | |||||
| logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) | |||||
| try: | |||||
| layer.on_graph_start() | |||||
| except Exception as e: | |||||
| logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) | |||||
| def _start_execution(self) -> None: | |||||
| """Start execution subsystems.""" | |||||
| # Start worker pool (it calculates initial workers internally) | |||||
| self._worker_pool.start() | |||||
| # Register response nodes | |||||
| for node in self._graph.nodes.values(): | |||||
| if node.execution_type == NodeExecutionType.RESPONSE: | |||||
| self._response_coordinator.register(node.id) | |||||
| # Enqueue root node | |||||
| root_node = self._graph.root_node | |||||
| self._state_manager.enqueue_node(root_node.id) | |||||
| self._state_manager.start_execution(root_node.id) | |||||
| # Start dispatcher | |||||
| self._dispatcher.start() | |||||
| def _stop_execution(self) -> None: | |||||
| """Stop execution subsystems.""" | |||||
| self._dispatcher.stop() | |||||
| self._worker_pool.stop() | |||||
| # Don't mark complete here as the dispatcher already does it | |||||
| # Notify layers | |||||
| for layer in self._layers: | |||||
| try: | |||||
| layer.on_graph_end(self._graph_execution.error) | |||||
| except Exception as e: | |||||
| logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) | |||||
| # Public property accessors for attributes that need external access | |||||
| @property | |||||
| def graph_runtime_state(self) -> GraphRuntimeState: | |||||
| """Get the graph runtime state.""" | |||||
| return self._graph_runtime_state |
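# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal sketch of driving GraphEngine, assuming a Graph, a GraphRuntimeState,
# and a CommandChannel have already been built elsewhere; the literal values below
# are placeholders rather than values taken from a real workflow.
def _example_run_graph_engine(graph, graph_runtime_state, command_channel):
    engine = GraphEngine(
        tenant_id="tenant-id",
        app_id="app-id",
        workflow_id="workflow-id",
        user_id="user-id",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
        graph=graph,
        graph_config={},
        graph_runtime_state=graph_runtime_state,
        max_execution_steps=500,
        max_execution_time=600,
        command_channel=command_channel,
    )
    # Events are yielded while the workflow executes; fail loudly to keep the sketch short.
    for event in engine.run():
        if isinstance(event, GraphRunFailedEvent):
            raise RuntimeError(event.error)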
| from collections.abc import Generator, Mapping, Sequence | |||||
| from typing import TYPE_CHECKING, Any, Optional | |||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler | |||||
| from core.file import File, FileTransferMethod | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||||
| from core.tools.errors import ToolInvokeError | |||||
| from core.tools.tool_engine import ToolEngine | |||||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||||
| from core.variables.segments import ArrayAnySegment, ArrayFileSegment | |||||
| from core.variables.variables import ArrayAnyVariable | |||||
| from core.workflow.enums import ( | |||||
| ErrorStrategy, | |||||
| NodeType, | |||||
| SystemVariableKey, | |||||
| WorkflowNodeExecutionMetadataKey, | |||||
| WorkflowNodeExecutionStatus, | |||||
| ) | |||||
| from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent | |||||
| from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||||
| from core.workflow.nodes.base.node import Node | |||||
| from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser | |||||
| from extensions.ext_database import db | |||||
| from factories import file_factory | |||||
| from models import ToolFile | |||||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||||
| from .entities import ToolNodeData | |||||
| from .exc import ( | |||||
| ToolFileError, | |||||
| ToolNodeError, | |||||
| ToolParameterError, | |||||
| ) | |||||
| if TYPE_CHECKING: | |||||
| from core.workflow.entities import VariablePool | |||||
| class ToolNode(Node): | |||||
| """ | |||||
| Tool Node | |||||
| """ | |||||
| node_type = NodeType.TOOL | |||||
| _node_data: ToolNodeData | |||||
| def init_node_data(self, data: Mapping[str, Any]) -> None: | |||||
| self._node_data = ToolNodeData.model_validate(data) | |||||
| @classmethod | |||||
| def version(cls) -> str: | |||||
| return "1" | |||||
| def _run(self) -> Generator: | |||||
| """ | |||||
| Run the tool node | |||||
| """ | |||||
| from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError | |||||
| node_data = self._node_data | |||||
| # fetch tool icon | |||||
| tool_info = { | |||||
| "provider_type": node_data.provider_type.value, | |||||
| "provider_id": node_data.provider_id, | |||||
| "plugin_unique_identifier": node_data.plugin_unique_identifier, | |||||
| } | |||||
| # get tool runtime | |||||
| try: | |||||
| from core.tools.tool_manager import ToolManager | |||||
# Note: branching on node_data.version here has caused problems before.
# Ideally this check would not rely on the version field, but it is kept
# for backward compatibility with historical workflow data.
| variable_pool: VariablePool | None = None | |||||
| if node_data.version != "1" or node_data.tool_node_version != "1": | |||||
| variable_pool = self.graph_runtime_state.variable_pool | |||||
| tool_runtime = ToolManager.get_workflow_tool_runtime( | |||||
| self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool | |||||
| ) | |||||
| except ToolNodeError as e: | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.FAILED, | |||||
| inputs={}, | |||||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||||
| error=f"Failed to get tool runtime: {str(e)}", | |||||
| error_type=type(e).__name__, | |||||
| ) | |||||
| ) | |||||
| return | |||||
| # get parameters | |||||
| tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] | |||||
| parameters = self._generate_parameters( | |||||
| tool_parameters=tool_parameters, | |||||
| variable_pool=self.graph_runtime_state.variable_pool, | |||||
| node_data=self._node_data, | |||||
| ) | |||||
| parameters_for_log = self._generate_parameters( | |||||
| tool_parameters=tool_parameters, | |||||
| variable_pool=self.graph_runtime_state.variable_pool, | |||||
| node_data=self._node_data, | |||||
| for_log=True, | |||||
| ) | |||||
| # get conversation id | |||||
| conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) | |||||
| try: | |||||
| message_stream = ToolEngine.generic_invoke( | |||||
| tool=tool_runtime, | |||||
| tool_parameters=parameters, | |||||
| user_id=self.user_id, | |||||
| workflow_tool_callback=DifyWorkflowCallbackHandler(), | |||||
| workflow_call_depth=self.workflow_call_depth, | |||||
| app_id=self.app_id, | |||||
| conversation_id=conversation_id.text if conversation_id else None, | |||||
| ) | |||||
| except ToolNodeError as e: | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.FAILED, | |||||
| inputs=parameters_for_log, | |||||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||||
| error=f"Failed to invoke tool: {str(e)}", | |||||
| error_type=type(e).__name__, | |||||
| ) | |||||
| ) | |||||
| return | |||||
| try: | |||||
| # convert tool messages | |||||
| yield from self._transform_message( | |||||
| messages=message_stream, | |||||
| tool_info=tool_info, | |||||
| parameters_for_log=parameters_for_log, | |||||
| user_id=self.user_id, | |||||
| tenant_id=self.tenant_id, | |||||
| node_id=self._node_id, | |||||
| ) | |||||
| except ToolInvokeError as e: | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.FAILED, | |||||
| inputs=parameters_for_log, | |||||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||||
| error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}", | |||||
| error_type=type(e).__name__, | |||||
| ) | |||||
| ) | |||||
| except PluginInvokeError as e: | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.FAILED, | |||||
| inputs=parameters_for_log, | |||||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||||
| error="An error occurred in the plugin, " | |||||
| f"please contact the author of {node_data.provider_name} for help, " | |||||
| f"error type: {e.get_error_type()}, " | |||||
| f"error details: {e.get_error_message()}", | |||||
| error_type=type(e).__name__, | |||||
| ) | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.FAILED, | |||||
| inputs=parameters_for_log, | |||||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||||
| error=f"Failed to invoke tool, error: {e.description}", | |||||
| error_type=type(e).__name__, | |||||
| ) | |||||
| ) | |||||
| def _generate_parameters( | |||||
| self, | |||||
| *, | |||||
| tool_parameters: Sequence[ToolParameter], | |||||
| variable_pool: "VariablePool", | |||||
| node_data: ToolNodeData, | |||||
| for_log: bool = False, | |||||
| ) -> dict[str, Any]: | |||||
| """ | |||||
| Generate parameters based on the given tool parameters, variable pool, and node data. | |||||
| Args: | |||||
| tool_parameters (Sequence[ToolParameter]): The list of tool parameters. | |||||
| variable_pool (VariablePool): The variable pool containing the variables. | |||||
| node_data (ToolNodeData): The data associated with the tool node. | |||||
| Returns: | |||||
| Mapping[str, Any]: A dictionary containing the generated parameters. | |||||
| """ | |||||
| tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} | |||||
| result: dict[str, Any] = {} | |||||
| for parameter_name in node_data.tool_parameters: | |||||
| parameter = tool_parameters_dictionary.get(parameter_name) | |||||
| if not parameter: | |||||
| result[parameter_name] = None | |||||
| continue | |||||
| tool_input = node_data.tool_parameters[parameter_name] | |||||
| if tool_input.type == "variable": | |||||
| variable = variable_pool.get(tool_input.value) | |||||
| if variable is None: | |||||
| if parameter.required: | |||||
| raise ToolParameterError(f"Variable {tool_input.value} does not exist") | |||||
| continue | |||||
| parameter_value = variable.value | |||||
| elif tool_input.type in {"mixed", "constant"}: | |||||
| segment_group = variable_pool.convert_template(str(tool_input.value)) | |||||
| parameter_value = segment_group.log if for_log else segment_group.text | |||||
| else: | |||||
| raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") | |||||
| result[parameter_name] = parameter_value | |||||
| return result | |||||
| def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: | |||||
| variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) | |||||
| assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) | |||||
| return list(variable.value) if variable else [] | |||||
| def _transform_message( | |||||
| self, | |||||
| messages: Generator[ToolInvokeMessage, None, None], | |||||
| tool_info: Mapping[str, Any], | |||||
| parameters_for_log: dict[str, Any], | |||||
| user_id: str, | |||||
| tenant_id: str, | |||||
| node_id: str, | |||||
| ) -> Generator: | |||||
| """ | |||||
| Convert ToolInvokeMessages into tuple[plain_text, files] | |||||
| """ | |||||
| # transform message and handle file storage | |||||
| from core.plugin.impl.plugin import PluginInstaller | |||||
| message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( | |||||
| messages=messages, | |||||
| user_id=user_id, | |||||
| tenant_id=tenant_id, | |||||
| conversation_id=None, | |||||
| ) | |||||
| text = "" | |||||
| files: list[File] = [] | |||||
| json: list[dict] = [] | |||||
| variables: dict[str, Any] = {} | |||||
| for message in message_stream: | |||||
| if message.type in { | |||||
| ToolInvokeMessage.MessageType.IMAGE_LINK, | |||||
| ToolInvokeMessage.MessageType.BINARY_LINK, | |||||
| ToolInvokeMessage.MessageType.IMAGE, | |||||
| }: | |||||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||||
| url = message.message.text | |||||
| if message.meta: | |||||
| transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) | |||||
| else: | |||||
| transfer_method = FileTransferMethod.TOOL_FILE | |||||
| tool_file_id = str(url).split("/")[-1].split(".")[0] | |||||
| with Session(db.engine) as session: | |||||
| stmt = select(ToolFile).where(ToolFile.id == tool_file_id) | |||||
| tool_file = session.scalar(stmt) | |||||
| if tool_file is None: | |||||
| raise ToolFileError(f"Tool file {tool_file_id} does not exist") | |||||
| mapping = { | |||||
| "tool_file_id": tool_file_id, | |||||
| "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), | |||||
| "transfer_method": transfer_method, | |||||
| "url": url, | |||||
| } | |||||
| file = file_factory.build_from_mapping( | |||||
| mapping=mapping, | |||||
| tenant_id=tenant_id, | |||||
| ) | |||||
| files.append(file) | |||||
| elif message.type == ToolInvokeMessage.MessageType.BLOB: | |||||
| # get tool file id | |||||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||||
| assert message.meta | |||||
| tool_file_id = message.message.text.split("/")[-1].split(".")[0] | |||||
| with Session(db.engine) as session: | |||||
| stmt = select(ToolFile).where(ToolFile.id == tool_file_id) | |||||
| tool_file = session.scalar(stmt) | |||||
| if tool_file is None: | |||||
| raise ToolFileError(f"tool file {tool_file_id} not exists") | |||||
| mapping = { | |||||
| "tool_file_id": tool_file_id, | |||||
| "transfer_method": FileTransferMethod.TOOL_FILE, | |||||
| } | |||||
| files.append( | |||||
| file_factory.build_from_mapping( | |||||
| mapping=mapping, | |||||
| tenant_id=tenant_id, | |||||
| ) | |||||
| ) | |||||
| elif message.type == ToolInvokeMessage.MessageType.TEXT: | |||||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||||
| text += message.message.text | |||||
| yield StreamChunkEvent( | |||||
| selector=[node_id, "text"], | |||||
| chunk=message.message.text, | |||||
| is_final=False, | |||||
| ) | |||||
| elif message.type == ToolInvokeMessage.MessageType.JSON: | |||||
| assert isinstance(message.message, ToolInvokeMessage.JsonMessage) | |||||
| # JSON message handling for tool node | |||||
| if message.message.json_object is not None: | |||||
| json.append(message.message.json_object) | |||||
| elif message.type == ToolInvokeMessage.MessageType.LINK: | |||||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||||
| stream_text = f"Link: {message.message.text}\n" | |||||
| text += stream_text | |||||
| yield StreamChunkEvent( | |||||
| selector=[node_id, "text"], | |||||
| chunk=stream_text, | |||||
| is_final=False, | |||||
| ) | |||||
| elif message.type == ToolInvokeMessage.MessageType.VARIABLE: | |||||
| assert isinstance(message.message, ToolInvokeMessage.VariableMessage) | |||||
| variable_name = message.message.variable_name | |||||
| variable_value = message.message.variable_value | |||||
| if message.message.stream: | |||||
| if not isinstance(variable_value, str): | |||||
| raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") | |||||
| if variable_name not in variables: | |||||
| variables[variable_name] = "" | |||||
| variables[variable_name] += variable_value | |||||
| yield StreamChunkEvent( | |||||
| selector=[node_id, variable_name], | |||||
| chunk=variable_value, | |||||
| is_final=False, | |||||
| ) | |||||
| else: | |||||
| variables[variable_name] = variable_value | |||||
| elif message.type == ToolInvokeMessage.MessageType.FILE: | |||||
| assert message.meta is not None | |||||
| assert isinstance(message.meta, dict) | |||||
| # Validate that meta contains a 'file' key | |||||
| if "file" not in message.meta: | |||||
| raise ToolNodeError("File message is missing 'file' key in meta") | |||||
| # Validate that the file is an instance of File | |||||
| if not isinstance(message.meta["file"], File): | |||||
| raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") | |||||
| files.append(message.meta["file"]) | |||||
| elif message.type == ToolInvokeMessage.MessageType.LOG: | |||||
| assert isinstance(message.message, ToolInvokeMessage.LogMessage) | |||||
| if message.message.metadata: | |||||
| icon = tool_info.get("icon", "") | |||||
| dict_metadata = dict(message.message.metadata) | |||||
| if dict_metadata.get("provider"): | |||||
| manager = PluginInstaller() | |||||
| plugins = manager.list_plugins(tenant_id) | |||||
| try: | |||||
| current_plugin = next( | |||||
| plugin | |||||
| for plugin in plugins | |||||
| if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] | |||||
| ) | |||||
| icon = current_plugin.declaration.icon | |||||
| except StopIteration: | |||||
| pass | |||||
| icon_dark = None | |||||
| try: | |||||
| builtin_tool = next( | |||||
| provider | |||||
| for provider in BuiltinToolManageService.list_builtin_tools( | |||||
| user_id, | |||||
| tenant_id, | |||||
| ) | |||||
| if provider.name == dict_metadata["provider"] | |||||
| ) | |||||
| icon = builtin_tool.icon | |||||
| icon_dark = builtin_tool.icon_dark | |||||
| except StopIteration: | |||||
| pass | |||||
| dict_metadata["icon"] = icon | |||||
| dict_metadata["icon_dark"] = icon_dark | |||||
| message.message.metadata = dict_metadata | |||||
# Normalize the collected JSON messages into the 'json' output so the frontend
# (e.g. the agent thinking process) can always read a list of dicts.
json_output: list[dict[str, Any]] = []
# Use the collected JSON objects if any; otherwise fall back to a placeholder {"data": []}
| if json: | |||||
| json_output.extend(json) | |||||
| else: | |||||
| json_output.append({"data": []}) | |||||
| # Send final chunk events for all streamed outputs | |||||
| # Final chunk for text stream | |||||
| yield StreamChunkEvent( | |||||
| selector=[self._node_id, "text"], | |||||
| chunk="", | |||||
| is_final=True, | |||||
| ) | |||||
| # Final chunks for any streamed variables | |||||
| for var_name in variables: | |||||
| yield StreamChunkEvent( | |||||
| selector=[self._node_id, var_name], | |||||
| chunk="", | |||||
| is_final=True, | |||||
| ) | |||||
| yield StreamCompletedEvent( | |||||
| node_run_result=NodeRunResult( | |||||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||||
| outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, | |||||
| metadata={ | |||||
| WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, | |||||
| }, | |||||
| inputs=parameters_for_log, | |||||
| ) | |||||
| ) | |||||
| @classmethod | |||||
| def _extract_variable_selector_to_variable_mapping( | |||||
| cls, | |||||
| *, | |||||
| graph_config: Mapping[str, Any], | |||||
| node_id: str, | |||||
| node_data: Mapping[str, Any], | |||||
| ) -> Mapping[str, Sequence[str]]: | |||||
| """ | |||||
| Extract variable selector to variable mapping | |||||
| :param graph_config: graph config | |||||
| :param node_id: node id | |||||
| :param node_data: node data | |||||
| :return: | |||||
| """ | |||||
| # Create typed NodeData from dict | |||||
| typed_node_data = ToolNodeData.model_validate(node_data) | |||||
| result = {} | |||||
| for parameter_name in typed_node_data.tool_parameters: | |||||
| input = typed_node_data.tool_parameters[parameter_name] | |||||
| if input.type == "mixed": | |||||
| assert isinstance(input.value, str) | |||||
| selectors = VariableTemplateParser(input.value).extract_variable_selectors() | |||||
| for selector in selectors: | |||||
| result[selector.variable] = selector.value_selector | |||||
| elif input.type == "variable": | |||||
| result[parameter_name] = input.value | |||||
| elif input.type == "constant": | |||||
| pass | |||||
| result = {node_id + "." + key: value for key, value in result.items()} | |||||
| return result | |||||
| def _get_error_strategy(self) -> Optional[ErrorStrategy]: | |||||
| return self._node_data.error_strategy | |||||
| def _get_retry_config(self) -> RetryConfig: | |||||
| return self._node_data.retry_config | |||||
| def _get_title(self) -> str: | |||||
| return self._node_data.title | |||||
| def _get_description(self) -> Optional[str]: | |||||
| return self._node_data.desc | |||||
| def _get_default_value_dict(self) -> dict[str, Any]: | |||||
| return self._node_data.default_value_dict | |||||
| def get_base_node_data(self) -> BaseNodeData: | |||||
| return self._node_data | |||||
| @property | |||||
| def retry(self) -> bool: | |||||
| return self._node_data.retry_config.retry_enabled |
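# --- Configuration sketch (illustrative only) ---
# _generate_parameters() resolves each entry in node_data.tool_parameters by its
# input type: "variable" reads a selector from the variable pool, while "mixed"
# and "constant" render the value as a template against the pool. A tool node's
# parameter config therefore looks roughly like the following; the keys and values
# are placeholders, not taken from a real workflow.
_EXAMPLE_TOOL_PARAMETERS: dict[str, dict[str, object]] = {
    "query": {"type": "variable", "value": ["start", "user_query"]},
    "language": {"type": "constant", "value": "en"},
    "prompt": {"type": "mixed", "value": "Summarize: {{#start.user_query#}}"},
}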
| import logging | |||||
| import time | |||||
| import uuid | |||||
| from collections.abc import Generator, Mapping, Sequence | |||||
from typing import Any, Optional, cast
| from configs import dify_config | |||||
| from core.app.apps.exc import GenerateTaskStoppedError | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from core.file.models import File | |||||
| from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID | |||||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||||
| from core.workflow.graph import Graph | |||||
| from core.workflow.graph_engine import GraphEngine | |||||
| from core.workflow.graph_engine.command_channels import InMemoryChannel | |||||
| from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer | |||||
| from core.workflow.graph_engine.protocols.command_channel import CommandChannel | |||||
| from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent | |||||
| from core.workflow.nodes import NodeType | |||||
| from core.workflow.nodes.base.node import Node | |||||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||||
| from core.workflow.system_variable import SystemVariable | |||||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||||
| from factories import file_factory | |||||
| from models.enums import UserFrom | |||||
| from models.workflow import Workflow | |||||
| logger = logging.getLogger(__name__) | |||||
| class WorkflowEntry: | |||||
| def __init__( | |||||
| self, | |||||
| tenant_id: str, | |||||
| app_id: str, | |||||
| workflow_id: str, | |||||
| graph_config: Mapping[str, Any], | |||||
| graph: Graph, | |||||
| user_id: str, | |||||
| user_from: UserFrom, | |||||
| invoke_from: InvokeFrom, | |||||
| call_depth: int, | |||||
| graph_runtime_state: GraphRuntimeState, | |||||
| command_channel: Optional[CommandChannel] = None, | |||||
| ) -> None: | |||||
| """ | |||||
| Init workflow entry | |||||
| :param tenant_id: tenant id | |||||
| :param app_id: app id | |||||
| :param workflow_id: workflow id | |||||
| :param workflow_type: workflow type | |||||
| :param graph_config: workflow graph config | |||||
| :param graph: workflow graph | |||||
| :param user_id: user id | |||||
| :param user_from: user from | |||||
| :param invoke_from: invoke from | |||||
| :param call_depth: call depth | |||||
| :param variable_pool: variable pool | |||||
| :param graph_runtime_state: pre-created graph runtime state | |||||
| :param command_channel: command channel for external control (optional, defaults to InMemoryChannel) | |||||
| :param thread_pool_id: thread pool id | |||||
| """ | |||||
| # check call depth | |||||
| workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH | |||||
| if call_depth > workflow_call_max_depth: | |||||
| raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.") | |||||
| # Use provided command channel or default to InMemoryChannel | |||||
| if command_channel is None: | |||||
| command_channel = InMemoryChannel() | |||||
| self.command_channel = command_channel | |||||
| self.graph_engine = GraphEngine( | |||||
| tenant_id=tenant_id, | |||||
| app_id=app_id, | |||||
| workflow_id=workflow_id, | |||||
| user_id=user_id, | |||||
| user_from=user_from, | |||||
| invoke_from=invoke_from, | |||||
| call_depth=call_depth, | |||||
| graph=graph, | |||||
| graph_config=graph_config, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, | |||||
| max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, | |||||
| command_channel=command_channel, | |||||
| ) | |||||
| # Add debug logging layer when in debug mode | |||||
| if dify_config.DEBUG: | |||||
| logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine") | |||||
| debug_layer = DebugLoggingLayer( | |||||
| level="DEBUG", | |||||
| include_inputs=True, | |||||
| include_outputs=True, | |||||
| include_process_data=False, # Process data can be very verbose | |||||
| logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger | |||||
| ) | |||||
| self.graph_engine.layer(debug_layer) | |||||
| # Add execution limits layer | |||||
| limits_layer = ExecutionLimitsLayer( | |||||
| max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME | |||||
| ) | |||||
| self.graph_engine.layer(limits_layer) | |||||
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||||
| graph_engine = self.graph_engine | |||||
| try: | |||||
| # run workflow | |||||
| generator = graph_engine.run() | |||||
| yield from generator | |||||
| except GenerateTaskStoppedError: | |||||
| pass | |||||
| except Exception as e: | |||||
| logger.exception("Unknown Error when workflow entry running") | |||||
| yield GraphRunFailedEvent(error=str(e)) | |||||
| return | |||||
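# Illustrative note: unlike GraphEngine.run(), this wrapper does not re-raise
# unexpected exceptions; it logs them and yields a terminal GraphRunFailedEvent,
# while GenerateTaskStoppedError is treated as a normal stop. A caller sketch:
#   for event in workflow_entry.run():
#       if isinstance(event, GraphRunFailedEvent):
#           ...  # surface event.error to the user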
| @classmethod | |||||
| def single_step_run( | |||||
| cls, | |||||
| *, | |||||
| workflow: Workflow, | |||||
| node_id: str, | |||||
| user_id: str, | |||||
| user_inputs: Mapping[str, Any], | |||||
| variable_pool: VariablePool, | |||||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||||
| ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: | |||||
| """ | |||||
| Single step run workflow node | |||||
| :param workflow: Workflow instance | |||||
| :param node_id: node id | |||||
| :param user_id: user id | |||||
| :param user_inputs: user inputs | |||||
| :return: | |||||
| """ | |||||
| node_config = workflow.get_node_config_by_id(node_id) | |||||
| node_config_data = node_config.get("data", {}) | |||||
| # Get node class | |||||
| node_type = NodeType(node_config_data.get("type")) | |||||
| node_version = node_config_data.get("version", "1") | |||||
| node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] | |||||
| # init graph init params and runtime state | |||||
| graph_init_params = GraphInitParams( | |||||
| tenant_id=workflow.tenant_id, | |||||
| app_id=workflow.app_id, | |||||
| workflow_id=workflow.id, | |||||
| graph_config=workflow.graph_dict, | |||||
| user_id=user_id, | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.DEBUGGER, | |||||
| call_depth=0, | |||||
| ) | |||||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||||
| # init node factory | |||||
| node_factory = DifyNodeFactory( | |||||
| graph_init_params=graph_init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| # init graph | |||||
| graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory) | |||||
| # init workflow run state | |||||
| node = node_cls( | |||||
| id=str(uuid.uuid4()), | |||||
| config=node_config, | |||||
| graph_init_params=graph_init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| node.init_node_data(node_config_data) | |||||
| try: | |||||
| # variable selector to variable mapping | |||||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||||
| graph_config=workflow.graph_dict, config=node_config | |||||
| ) | |||||
| except NotImplementedError: | |||||
| variable_mapping = {} | |||||
# Load missing variables from draft variables here and set them into the
# variable_pool.
| load_into_variable_pool( | |||||
| variable_loader=variable_loader, | |||||
| variable_pool=variable_pool, | |||||
| variable_mapping=variable_mapping, | |||||
| user_inputs=user_inputs, | |||||
| ) | |||||
| cls.mapping_user_inputs_to_variable_pool( | |||||
| variable_mapping=variable_mapping, | |||||
| user_inputs=user_inputs, | |||||
| variable_pool=variable_pool, | |||||
| tenant_id=workflow.tenant_id, | |||||
| ) | |||||
| try: | |||||
| # run node | |||||
| generator = node.run() | |||||
| except Exception as e: | |||||
| logger.exception( | |||||
| "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", | |||||
| workflow.id, | |||||
| node.id, | |||||
| node.node_type, | |||||
| node.version(), | |||||
| ) | |||||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | |||||
| return node, generator | |||||
| @staticmethod | |||||
| def _create_single_node_graph( | |||||
| node_id: str, | |||||
| node_data: dict[str, Any], | |||||
| node_width: int = 114, | |||||
| node_height: int = 514, | |||||
| ) -> dict[str, Any]: | |||||
| """ | |||||
| Create a minimal graph structure for testing a single node in isolation. | |||||
| :param node_id: ID of the target node | |||||
| :param node_data: configuration data for the target node | |||||
| :param node_width: width for UI layout (default: 200) | |||||
| :param node_height: height for UI layout (default: 100) | |||||
| :return: graph dictionary with start node and target node | |||||
| """ | |||||
| node_config = { | |||||
| "id": node_id, | |||||
| "width": node_width, | |||||
| "height": node_height, | |||||
| "type": "custom", | |||||
| "data": node_data, | |||||
| } | |||||
| start_node_config = { | |||||
| "id": "start", | |||||
| "width": node_width, | |||||
| "height": node_height, | |||||
| "type": "custom", | |||||
| "data": { | |||||
| "type": NodeType.START.value, | |||||
| "title": "Start", | |||||
| "desc": "Start", | |||||
| }, | |||||
| } | |||||
| return { | |||||
| "nodes": [start_node_config, node_config], | |||||
| "edges": [ | |||||
| { | |||||
| "source": "start", | |||||
| "target": node_id, | |||||
| "sourceHandle": "source", | |||||
| "targetHandle": "target", | |||||
| } | |||||
| ], | |||||
| } | |||||
| @classmethod | |||||
| def run_free_node( | |||||
| cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] | |||||
| ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: | |||||
| """ | |||||
| Run free node | |||||
| NOTE: only parameter_extractor/question_classifier are supported | |||||
| :param node_data: node data | |||||
| :param node_id: node id | |||||
| :param tenant_id: tenant id | |||||
| :param user_id: user id | |||||
| :param user_inputs: user inputs | |||||
| :return: | |||||
| """ | |||||
| # Create a minimal graph for single node execution | |||||
| graph_dict = cls._create_single_node_graph(node_id, node_data) | |||||
| node_type = NodeType(node_data.get("type", "")) | |||||
| if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: | |||||
| raise ValueError(f"Node type {node_type} not supported") | |||||
| node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] | |||||
| if not node_cls: | |||||
| raise ValueError(f"Node class not found for node type {node_type}") | |||||
| # init variable pool | |||||
| variable_pool = VariablePool( | |||||
| system_variables=SystemVariable.empty(), | |||||
| user_inputs={}, | |||||
| environment_variables=[], | |||||
| ) | |||||
| # init graph init params and runtime state | |||||
| graph_init_params = GraphInitParams( | |||||
| tenant_id=tenant_id, | |||||
| app_id="", | |||||
| workflow_id="", | |||||
| graph_config=graph_dict, | |||||
| user_id=user_id, | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.DEBUGGER, | |||||
| call_depth=0, | |||||
| ) | |||||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||||
| # init node factory | |||||
| node_factory = DifyNodeFactory( | |||||
| graph_init_params=graph_init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| # init graph | |||||
| graph = Graph.init(graph_config=graph_dict, node_factory=node_factory) | |||||
| node_cls = cast(type[Node], node_cls) | |||||
| # init workflow run state | |||||
| node_config = { | |||||
| "id": node_id, | |||||
| "data": node_data, | |||||
| } | |||||
| node: Node = node_cls( | |||||
| id=str(uuid.uuid4()), | |||||
| config=node_config, | |||||
| graph_init_params=graph_init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| node.init_node_data(node_data) | |||||
| try: | |||||
| # variable selector to variable mapping | |||||
| try: | |||||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||||
| graph_config=graph_dict, config=node_config | |||||
| ) | |||||
| except NotImplementedError: | |||||
| variable_mapping = {} | |||||
| cls.mapping_user_inputs_to_variable_pool( | |||||
| variable_mapping=variable_mapping, | |||||
| user_inputs=user_inputs, | |||||
| variable_pool=variable_pool, | |||||
| tenant_id=tenant_id, | |||||
| ) | |||||
| # run node | |||||
| generator = node.run() | |||||
| return node, generator | |||||
| except Exception as e: | |||||
| logger.exception( | |||||
| "error while running node, node_id=%s, node_type=%s, node_version=%s", | |||||
| node.id, | |||||
| node.node_type, | |||||
| node.version(), | |||||
| ) | |||||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | |||||
| @staticmethod | |||||
| def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: | |||||
| # NOTE(QuantumGhost): Avoid using this function in new code. | |||||
| # Keep values structured as long as possible and only convert to dict | |||||
| # immediately before serialization (e.g., JSON serialization) to maintain | |||||
| # data integrity and type information. | |||||
| result = WorkflowEntry._handle_special_values(value) | |||||
| return result if isinstance(result, Mapping) or result is None else dict(result) | |||||
| @staticmethod | |||||
| def _handle_special_values(value: Any) -> Any: | |||||
| if value is None: | |||||
| return value | |||||
| if isinstance(value, dict): | |||||
| res = {} | |||||
| for k, v in value.items(): | |||||
| res[k] = WorkflowEntry._handle_special_values(v) | |||||
| return res | |||||
| if isinstance(value, list): | |||||
| res_list = [] | |||||
| for item in value: | |||||
| res_list.append(WorkflowEntry._handle_special_values(item)) | |||||
| return res_list | |||||
| if isinstance(value, File): | |||||
| return value.to_dict() | |||||
| return value | |||||
| @classmethod | |||||
| def mapping_user_inputs_to_variable_pool( | |||||
| cls, | |||||
| *, | |||||
| variable_mapping: Mapping[str, Sequence[str]], | |||||
| user_inputs: Mapping[str, Any], | |||||
| variable_pool: VariablePool, | |||||
| tenant_id: str, | |||||
| ) -> None: | |||||
| # NOTE(QuantumGhost): This logic should remain synchronized with | |||||
| # the implementation of `load_into_variable_pool`, specifically the logic about | |||||
| # variable existence checking. | |||||
| # WARNING(QuantumGhost): The semantics of this method are not clearly defined, | |||||
| # and multiple parts of the codebase depend on its current behavior. | |||||
| # Modify with caution. | |||||
| for node_variable, variable_selector in variable_mapping.items(): | |||||
| # fetch node id and variable key from node_variable | |||||
| node_variable_list = node_variable.split(".") | |||||
| if len(node_variable_list) < 1: | |||||
| raise ValueError(f"Invalid node variable {node_variable}") | |||||
| node_variable_key = ".".join(node_variable_list[1:]) | |||||
| if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( | |||||
| variable_selector | |||||
| ): | |||||
| raise ValueError(f"Variable key {node_variable} not found in user inputs.") | |||||
| # environment variable already exist in variable pool, not from user inputs | |||||
| if variable_pool.get(variable_selector): | |||||
| continue | |||||
| # fetch variable node id from variable selector | |||||
| variable_node_id = variable_selector[0] | |||||
| variable_key_list = variable_selector[1:] | |||||
| variable_key_list = list(variable_key_list) | |||||
| # get input value | |||||
| input_value = user_inputs.get(node_variable) | |||||
| if not input_value: | |||||
| input_value = user_inputs.get(node_variable_key) | |||||
| if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: | |||||
| input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) | |||||
| if ( | |||||
| isinstance(input_value, list) | |||||
| and all(isinstance(item, dict) for item in input_value) | |||||
| and all("type" in item and "transfer_method" in item for item in input_value) | |||||
| ): | |||||
| input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) | |||||
| # append variable and value to variable pool | |||||
| if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: | |||||
| variable_pool.add([variable_node_id] + variable_key_list, input_value) |
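# --- Usage sketch (illustrative only) ---
# A minimal sketch of debugging a single node with single_step_run, assuming a
# Workflow model instance is available; the node id and inputs are placeholders.
def _example_single_step_run(workflow: Workflow):
    variable_pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={},
        environment_variables=[],
    )
    node, events = WorkflowEntry.single_step_run(
        workflow=workflow,
        node_id="node-id",
        user_id="user-id",
        user_inputs={"node-id.query": "hello"},
        variable_pool=variable_pool,
    )
    # Drain the generator; each item is a GraphNodeEventBase subclass.
    for event in events:
        print(type(event).__name__)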
| import time | |||||
| import uuid | |||||
| from os import getenv | |||||
| import pytest | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||||
| from core.workflow.enums import WorkflowNodeExecutionStatus | |||||
| from core.workflow.graph import Graph | |||||
| from core.workflow.node_events import NodeRunResult | |||||
| from core.workflow.nodes.code.code_node import CodeNode | |||||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||||
| from core.workflow.system_variable import SystemVariable | |||||
| from models.enums import UserFrom | |||||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||||
| CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) | |||||
| def init_code_node(code_config: dict): | |||||
| graph_config = { | |||||
| "edges": [ | |||||
| { | |||||
| "id": "start-source-code-target", | |||||
| "source": "start", | |||||
| "target": "code", | |||||
| }, | |||||
| ], | |||||
| "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], | |||||
| } | |||||
| init_params = GraphInitParams( | |||||
| tenant_id="1", | |||||
| app_id="1", | |||||
| workflow_id="1", | |||||
| graph_config=graph_config, | |||||
| user_id="1", | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.DEBUGGER, | |||||
| call_depth=0, | |||||
| ) | |||||
| # construct variable pool | |||||
| variable_pool = VariablePool( | |||||
| system_variables=SystemVariable(user_id="aaa", files=[]), | |||||
| user_inputs={}, | |||||
| environment_variables=[], | |||||
| conversation_variables=[], | |||||
| ) | |||||
| variable_pool.add(["code", "args1"], 1) | |||||
| variable_pool.add(["code", "args2"], 2) | |||||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||||
| # Create node factory | |||||
| node_factory = DifyNodeFactory( | |||||
| graph_init_params=init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||||
| node = CodeNode( | |||||
| id=str(uuid.uuid4()), | |||||
| config=code_config, | |||||
| graph_init_params=init_params, | |||||
| graph_runtime_state=graph_runtime_state, | |||||
| ) | |||||
| # Initialize node data | |||||
| if "data" in code_config: | |||||
| node.init_node_data(code_config["data"]) | |||||
| return node | |||||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||||
| def test_execute_code(setup_code_executor_mock): | |||||
| code = """ | |||||
| def main(args1: int, args2: int) -> dict: | |||||
| return { | |||||
| "result": args1 + args2, | |||||
| } | |||||
| """ | |||||
| # trim first 4 spaces at the beginning of each line | |||||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||||
| code_config = { | |||||
| "id": "code", | |||||
| "data": { | |||||
| "outputs": { | |||||
| "result": { | |||||
| "type": "number", | |||||
| }, | |||||
| }, | |||||
| "title": "123", | |||||
| "variables": [ | |||||
| { | |||||
| "variable": "args1", | |||||
| "value_selector": ["1", "args1"], | |||||
| }, | |||||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||||
| ], | |||||
| "answer": "123", | |||||
| "code_language": "python3", | |||||
| "code": code, | |||||
| }, | |||||
| } | |||||
| node = init_code_node(code_config) | |||||
| node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) | |||||
| node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) | |||||
| # execute node | |||||
| result = node._run() | |||||
| assert isinstance(result, NodeRunResult) | |||||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||||
| assert result.outputs is not None | |||||
| assert result.outputs["result"] == 3 | |||||
| assert result.error == "" | |||||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||||
| def test_execute_code_output_validator(setup_code_executor_mock): | |||||
| code = """ | |||||
| def main(args1: int, args2: int) -> dict: | |||||
| return { | |||||
| "result": args1 + args2, | |||||
| } | |||||
| """ | |||||
| # trim first 4 spaces at the beginning of each line | |||||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||||
| code_config = { | |||||
| "id": "code", | |||||
| "data": { | |||||
| "outputs": { | |||||
| "result": { | |||||
| "type": "string", | |||||
| }, | |||||
| }, | |||||
| "title": "123", | |||||
| "variables": [ | |||||
| { | |||||
| "variable": "args1", | |||||
| "value_selector": ["1", "args1"], | |||||
| }, | |||||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||||
| ], | |||||
| "answer": "123", | |||||
| "code_language": "python3", | |||||
| "code": code, | |||||
| }, | |||||
| } | |||||
| node = init_code_node(code_config) | |||||
| node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) | |||||
| node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) | |||||
| # execute node | |||||
| result = node._run() | |||||
| assert isinstance(result, NodeRunResult) | |||||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||||
| assert result.error == "Output variable `result` must be a string" | |||||
| def test_execute_code_output_validator_depth(): | |||||
| code = """ | |||||
| def main(args1: int, args2: int) -> dict: | |||||
| return { | |||||
| "result": { | |||||
| "result": args1 + args2, | |||||
| } | |||||
| } | |||||
| """ | |||||
| # trim first 4 spaces at the beginning of each line | |||||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||||
| code_config = { | |||||
| "id": "code", | |||||
| "data": { | |||||
| "outputs": { | |||||
| "string_validator": { | |||||
| "type": "string", | |||||
| }, | |||||
| "number_validator": { | |||||
| "type": "number", | |||||
| }, | |||||
| "number_array_validator": { | |||||
| "type": "array[number]", | |||||
| }, | |||||
| "string_array_validator": { | |||||
| "type": "array[string]", | |||||
| }, | |||||
| "object_validator": { | |||||
| "type": "object", | |||||
| "children": { | |||||
| "result": { | |||||
| "type": "number", | |||||
| }, | |||||
| "depth": { | |||||
| "type": "object", | |||||
| "children": { | |||||
| "depth": { | |||||
| "type": "object", | |||||
| "children": { | |||||
| "depth": { | |||||
| "type": "number", | |||||
| } | |||||
| }, | |||||
| } | |||||
| }, | |||||
| }, | |||||
| }, | |||||
| }, | |||||
| }, | |||||
| "title": "123", | |||||
| "variables": [ | |||||
| { | |||||
| "variable": "args1", | |||||
| "value_selector": ["1", "args1"], | |||||
| }, | |||||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||||
| ], | |||||
| "answer": "123", | |||||
| "code_language": "python3", | |||||
| "code": code, | |||||
| }, | |||||
| } | |||||
| node = init_code_node(code_config) | |||||
| # construct result | |||||
| result = { | |||||
| "number_validator": 1, | |||||
| "string_validator": "1", | |||||
| "number_array_validator": [1, 2, 3, 3.333], | |||||
| "string_array_validator": ["1", "2", "3"], | |||||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||||
| } | |||||
| # validate | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| # construct result | |||||
| result = { | |||||
| "number_validator": "1", | |||||
| "string_validator": 1, | |||||
| "number_array_validator": ["1", "2", "3", "3.333"], | |||||
| "string_array_validator": [1, 2, 3], | |||||
| "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, | |||||
| } | |||||
| # validate | |||||
| with pytest.raises(ValueError): | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| # construct result | |||||
| result = { | |||||
| "number_validator": 1, | |||||
| "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", | |||||
| "number_array_validator": [1, 2, 3, 3.333], | |||||
| "string_array_validator": ["1", "2", "3"], | |||||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||||
| } | |||||
| # validate | |||||
| with pytest.raises(ValueError): | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| # construct result | |||||
| result = { | |||||
| "number_validator": 1, | |||||
| "string_validator": "1", | |||||
| "number_array_validator": [1, 2, 3, 3.333] * 2000, | |||||
| "string_array_validator": ["1", "2", "3"], | |||||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||||
| } | |||||
| # validate | |||||
| with pytest.raises(ValueError): | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| def test_execute_code_output_object_list(): | |||||
| code = """ | |||||
| def main(args1: int, args2: int) -> dict: | |||||
| return { | |||||
| "result": { | |||||
| "result": args1 + args2, | |||||
| } | |||||
| } | |||||
| """ | |||||
| # trim first 4 spaces at the beginning of each line | |||||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||||
| code_config = { | |||||
| "id": "code", | |||||
| "data": { | |||||
| "outputs": { | |||||
| "object_list": { | |||||
| "type": "array[object]", | |||||
| }, | |||||
| }, | |||||
| "title": "123", | |||||
| "variables": [ | |||||
| { | |||||
| "variable": "args1", | |||||
| "value_selector": ["1", "args1"], | |||||
| }, | |||||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||||
| ], | |||||
| "answer": "123", | |||||
| "code_language": "python3", | |||||
| "code": code, | |||||
| }, | |||||
| } | |||||
| node = init_code_node(code_config) | |||||
| # construct result | |||||
| result = { | |||||
| "object_list": [ | |||||
| { | |||||
| "result": 1, | |||||
| }, | |||||
| { | |||||
| "result": 2, | |||||
| }, | |||||
| { | |||||
| "result": [1, 2, 3], | |||||
| }, | |||||
| ] | |||||
| } | |||||
| # validate | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| # construct result | |||||
| result = { | |||||
| "object_list": [ | |||||
| { | |||||
| "result": 1, | |||||
| }, | |||||
| { | |||||
| "result": 2, | |||||
| }, | |||||
| { | |||||
| "result": [1, 2, 3], | |||||
| }, | |||||
| 1, | |||||
| ] | |||||
| } | |||||
| # validate | |||||
| with pytest.raises(ValueError): | |||||
| node._transform_result(result, node._node_data.outputs) | |||||
| def test_execute_code_scientific_notation(): | |||||
| code = """ | |||||
| def main() -> dict: | |||||
| return { | |||||
| "result": -8.0E-5 | |||||
| } | |||||
| """ | |||||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||||
| code_config = { | |||||
| "id": "code", | |||||
| "data": { | |||||
| "outputs": { | |||||
| "result": { | |||||
| "type": "number", | |||||
| }, | |||||
| }, | |||||
| "title": "123", | |||||
| "variables": [], | |||||
| "answer": "123", | |||||
| "code_language": "python3", | |||||
| "code": code, | |||||
| }, | |||||
| } | |||||
| node = init_code_node(code_config) | |||||
| # execute node | |||||
| result = node._run() | |||||
| assert isinstance(result, NodeRunResult) | |||||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED |