Signed-off-by: -LAN- <laipz8200@outlook.com>tags/2.0.0-beta.1
| @@ -1,339 +0,0 @@ | |||
| """ | |||
| QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. | |||
| This engine uses a modular architecture with separated packages following | |||
| Domain-Driven Design principles for improved maintainability and testability. | |||
| """ | |||
| import contextvars | |||
| import logging | |||
| import queue | |||
| from collections.abc import Generator, Mapping | |||
| from typing import final | |||
| from flask import Flask, current_app | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities import GraphRuntimeState | |||
| from core.workflow.enums import NodeExecutionType | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import ( | |||
| GraphEngineEvent, | |||
| GraphNodeEventBase, | |||
| GraphRunAbortedEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| ) | |||
| from models.enums import UserFrom | |||
| from .command_processing import AbortCommandHandler, CommandProcessor | |||
| from .domain import ExecutionContext, GraphExecution | |||
| from .entities.commands import AbortCommand | |||
| from .error_handling import ErrorHandler | |||
| from .event_management import EventHandler, EventManager | |||
| from .graph_traversal import EdgeProcessor, SkipPropagator | |||
| from .layers.base import Layer | |||
| from .orchestration import Dispatcher, ExecutionCoordinator | |||
| from .protocols.command_channel import CommandChannel | |||
| from .response_coordinator import ResponseStreamCoordinator | |||
| from .state_management import UnifiedStateManager | |||
| from .worker_management import SimpleWorkerPool | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class GraphEngine: | |||
| """ | |||
| Queue-based graph execution engine. | |||
| Uses a modular architecture that delegates responsibilities to specialized | |||
| subsystems, following Domain-Driven Design and SOLID principles. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| call_depth: int, | |||
| graph: Graph, | |||
| graph_config: Mapping[str, object], | |||
| graph_runtime_state: GraphRuntimeState, | |||
| max_execution_steps: int, | |||
| max_execution_time: int, | |||
| command_channel: CommandChannel, | |||
| min_workers: int | None = None, | |||
| max_workers: int | None = None, | |||
| scale_up_threshold: int | None = None, | |||
| scale_down_idle_time: float | None = None, | |||
| ) -> None: | |||
| """Initialize the graph engine with all subsystems and dependencies.""" | |||
| # === Domain Models === | |||
| # Execution context encapsulates workflow execution metadata | |||
| self._execution_context = ExecutionContext( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_id=workflow_id, | |||
| user_id=user_id, | |||
| user_from=user_from, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| max_execution_steps=max_execution_steps, | |||
| max_execution_time=max_execution_time, | |||
| ) | |||
| # Graph execution tracks the overall execution state | |||
| self._graph_execution = GraphExecution(workflow_id=workflow_id) | |||
| # === Core Dependencies === | |||
| # Graph structure and configuration | |||
| self._graph = graph | |||
| self._graph_config = graph_config | |||
| self._graph_runtime_state = graph_runtime_state | |||
| self._command_channel = command_channel | |||
| # === Worker Management Parameters === | |||
| # Parameters for dynamic worker pool scaling | |||
| self._min_workers = min_workers | |||
| self._max_workers = max_workers | |||
| self._scale_up_threshold = scale_up_threshold | |||
| self._scale_down_idle_time = scale_down_idle_time | |||
| # === Execution Queues === | |||
| # Queue for nodes ready to execute | |||
| self._ready_queue: queue.Queue[str] = queue.Queue() | |||
| # Queue for events generated during execution | |||
| self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() | |||
| # === State Management === | |||
| # Unified state manager handles all node state transitions and queue operations | |||
| self._state_manager = UnifiedStateManager(self._graph, self._ready_queue) | |||
| # === Response Coordination === | |||
| # Coordinates response streaming from response nodes | |||
| self._response_coordinator = ResponseStreamCoordinator( | |||
| variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph | |||
| ) | |||
| # === Event Management === | |||
| # Event manager handles both collection and emission of events | |||
| self._event_manager = EventManager() | |||
| # === Error Handling === | |||
| # Centralized error handler for graph execution errors | |||
| self._error_handler = ErrorHandler(self._graph, self._graph_execution) | |||
| # === Graph Traversal Components === | |||
| # Propagates skip status through the graph when conditions aren't met | |||
| self._skip_propagator = SkipPropagator( | |||
| graph=self._graph, | |||
| state_manager=self._state_manager, | |||
| ) | |||
| # Processes edges to determine next nodes after execution | |||
| # Also handles conditional branching and route selection | |||
| self._edge_processor = EdgeProcessor( | |||
| graph=self._graph, | |||
| state_manager=self._state_manager, | |||
| response_coordinator=self._response_coordinator, | |||
| skip_propagator=self._skip_propagator, | |||
| ) | |||
| # === Event Handler Registry === | |||
| # Central registry for handling all node execution events | |||
| self._event_handler_registry = EventHandler( | |||
| graph=self._graph, | |||
| graph_runtime_state=self._graph_runtime_state, | |||
| graph_execution=self._graph_execution, | |||
| response_coordinator=self._response_coordinator, | |||
| event_collector=self._event_manager, | |||
| edge_processor=self._edge_processor, | |||
| state_manager=self._state_manager, | |||
| error_handler=self._error_handler, | |||
| ) | |||
| # === Command Processing === | |||
| # Processes external commands (e.g., abort requests) | |||
| self._command_processor = CommandProcessor( | |||
| command_channel=self._command_channel, | |||
| graph_execution=self._graph_execution, | |||
| ) | |||
| # Register abort command handler | |||
| abort_handler = AbortCommandHandler() | |||
| self._command_processor.register_handler( | |||
| AbortCommand, | |||
| abort_handler, | |||
| ) | |||
| # === Worker Pool Setup === | |||
| # Capture Flask app context for worker threads | |||
| flask_app: Flask | None = None | |||
| try: | |||
| app = current_app._get_current_object() # type: ignore | |||
| if isinstance(app, Flask): | |||
| flask_app = app | |||
| except RuntimeError: | |||
| pass | |||
| # Capture context variables for worker threads | |||
| context_vars = contextvars.copy_context() | |||
| # Create worker pool for parallel node execution | |||
| self._worker_pool = SimpleWorkerPool( | |||
| ready_queue=self._ready_queue, | |||
| event_queue=self._event_queue, | |||
| graph=self._graph, | |||
| flask_app=flask_app, | |||
| context_vars=context_vars, | |||
| min_workers=self._min_workers, | |||
| max_workers=self._max_workers, | |||
| scale_up_threshold=self._scale_up_threshold, | |||
| scale_down_idle_time=self._scale_down_idle_time, | |||
| ) | |||
| # === Orchestration === | |||
| # Coordinates the overall execution lifecycle | |||
| self._execution_coordinator = ExecutionCoordinator( | |||
| graph_execution=self._graph_execution, | |||
| state_manager=self._state_manager, | |||
| event_handler=self._event_handler_registry, | |||
| event_collector=self._event_manager, | |||
| command_processor=self._command_processor, | |||
| worker_pool=self._worker_pool, | |||
| ) | |||
| # Dispatches events and manages execution flow | |||
| self._dispatcher = Dispatcher( | |||
| event_queue=self._event_queue, | |||
| event_handler=self._event_handler_registry, | |||
| event_collector=self._event_manager, | |||
| execution_coordinator=self._execution_coordinator, | |||
| max_execution_time=self._execution_context.max_execution_time, | |||
| event_emitter=self._event_manager, | |||
| ) | |||
| # === Extensibility === | |||
| # Layers allow plugins to extend engine functionality | |||
| self._layers: list[Layer] = [] | |||
| # === Validation === | |||
| # Ensure all nodes share the same GraphRuntimeState instance | |||
| self._validate_graph_state_consistency() | |||
| def _validate_graph_state_consistency(self) -> None: | |||
| """Validate that all nodes share the same GraphRuntimeState.""" | |||
| expected_state_id = id(self._graph_runtime_state) | |||
| for node in self._graph.nodes.values(): | |||
| if id(node.graph_runtime_state) != expected_state_id: | |||
| raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") | |||
| def layer(self, layer: Layer) -> "GraphEngine": | |||
| """Add a layer for extending functionality.""" | |||
| self._layers.append(layer) | |||
| return self | |||
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| Execute the graph using the modular architecture. | |||
| Returns: | |||
| Generator yielding GraphEngineEvent instances | |||
| """ | |||
| try: | |||
| # Initialize layers | |||
| self._initialize_layers() | |||
| # Start execution | |||
| self._graph_execution.start() | |||
| start_event = GraphRunStartedEvent() | |||
| yield start_event | |||
| # Start subsystems | |||
| self._start_execution() | |||
| # Yield events as they occur | |||
| yield from self._event_manager.emit_events() | |||
| # Handle completion | |||
| if self._graph_execution.aborted: | |||
| abort_reason = "Workflow execution aborted by user command" | |||
| if self._graph_execution.error: | |||
| abort_reason = str(self._graph_execution.error) | |||
| yield GraphRunAbortedEvent( | |||
| reason=abort_reason, | |||
| outputs=self._graph_runtime_state.outputs, | |||
| ) | |||
| elif self._graph_execution.has_error: | |||
| if self._graph_execution.error: | |||
| raise self._graph_execution.error | |||
| else: | |||
| yield GraphRunSucceededEvent( | |||
| outputs=self._graph_runtime_state.outputs, | |||
| ) | |||
| except Exception as e: | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| raise | |||
| finally: | |||
| self._stop_execution() | |||
| def _initialize_layers(self) -> None: | |||
| """Initialize layers with context.""" | |||
| self._event_manager.set_layers(self._layers) | |||
| for layer in self._layers: | |||
| try: | |||
| layer.initialize(self._graph_runtime_state, self._command_channel) | |||
| except Exception as e: | |||
| logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) | |||
| try: | |||
| layer.on_graph_start() | |||
| except Exception as e: | |||
| logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) | |||
| def _start_execution(self) -> None: | |||
| """Start execution subsystems.""" | |||
| # Start worker pool (it calculates initial workers internally) | |||
| self._worker_pool.start() | |||
| # Register response nodes | |||
| for node in self._graph.nodes.values(): | |||
| if node.execution_type == NodeExecutionType.RESPONSE: | |||
| self._response_coordinator.register(node.id) | |||
| # Enqueue root node | |||
| root_node = self._graph.root_node | |||
| self._state_manager.enqueue_node(root_node.id) | |||
| self._state_manager.start_execution(root_node.id) | |||
| # Start dispatcher | |||
| self._dispatcher.start() | |||
| def _stop_execution(self) -> None: | |||
| """Stop execution subsystems.""" | |||
| self._dispatcher.stop() | |||
| self._worker_pool.stop() | |||
| # Don't mark complete here as the dispatcher already does it | |||
| # Notify layers | |||
| logger = logging.getLogger(__name__) | |||
| for layer in self._layers: | |||
| try: | |||
| layer.on_graph_end(self._graph_execution.error) | |||
| except Exception as e: | |||
| logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) | |||
| # Public property accessors for attributes that need external access | |||
| @property | |||
| def graph_runtime_state(self) -> GraphRuntimeState: | |||
| """Get the graph runtime state.""" | |||
| return self._graph_runtime_state | |||
| @@ -1,493 +0,0 @@ | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler | |||
| from core.file import File, FileTransferMethod | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.tools.tool_engine import ToolEngine | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from core.variables.segments import ArrayAnySegment, ArrayFileSegment | |||
| from core.variables.variables import ArrayAnyVariable | |||
| from core.workflow.enums import ( | |||
| ErrorStrategy, | |||
| NodeType, | |||
| SystemVariableKey, | |||
| WorkflowNodeExecutionMetadataKey, | |||
| WorkflowNodeExecutionStatus, | |||
| ) | |||
| from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent | |||
| from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import ToolFile | |||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||
| from .entities import ToolNodeData | |||
| from .exc import ( | |||
| ToolFileError, | |||
| ToolNodeError, | |||
| ToolParameterError, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from core.workflow.entities import VariablePool | |||
| class ToolNode(Node): | |||
| """ | |||
| Tool Node | |||
| """ | |||
| node_type = NodeType.TOOL | |||
| _node_data: ToolNodeData | |||
| def init_node_data(self, data: Mapping[str, Any]) -> None: | |||
| self._node_data = ToolNodeData.model_validate(data) | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator: | |||
| """ | |||
| Run the tool node | |||
| """ | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError | |||
| node_data = self._node_data | |||
| # fetch tool icon | |||
| tool_info = { | |||
| "provider_type": node_data.provider_type.value, | |||
| "provider_id": node_data.provider_id, | |||
| "plugin_unique_identifier": node_data.plugin_unique_identifier, | |||
| } | |||
| # get tool runtime | |||
| try: | |||
| from core.tools.tool_manager import ToolManager | |||
| # This is an issue that caused problems before. | |||
| # Logically, we shouldn't use the node_data.version field for judgment | |||
| # But for backward compatibility with historical data | |||
| # this version field judgment is still preserved here. | |||
| variable_pool: VariablePool | None = None | |||
| if node_data.version != "1" or node_data.tool_node_version != "1": | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| tool_runtime = ToolManager.get_workflow_tool_runtime( | |||
| self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool | |||
| ) | |||
| except ToolNodeError as e: | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs={}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to get tool runtime: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| return | |||
| # get parameters | |||
| tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] | |||
| parameters = self._generate_parameters( | |||
| tool_parameters=tool_parameters, | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| node_data=self._node_data, | |||
| ) | |||
| parameters_for_log = self._generate_parameters( | |||
| tool_parameters=tool_parameters, | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| node_data=self._node_data, | |||
| for_log=True, | |||
| ) | |||
| # get conversation id | |||
| conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) | |||
| try: | |||
| message_stream = ToolEngine.generic_invoke( | |||
| tool=tool_runtime, | |||
| tool_parameters=parameters, | |||
| user_id=self.user_id, | |||
| workflow_tool_callback=DifyWorkflowCallbackHandler(), | |||
| workflow_call_depth=self.workflow_call_depth, | |||
| app_id=self.app_id, | |||
| conversation_id=conversation_id.text if conversation_id else None, | |||
| ) | |||
| except ToolNodeError as e: | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to invoke tool: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| return | |||
| try: | |||
| # convert tool messages | |||
| yield from self._transform_message( | |||
| messages=message_stream, | |||
| tool_info=tool_info, | |||
| parameters_for_log=parameters_for_log, | |||
| user_id=self.user_id, | |||
| tenant_id=self.tenant_id, | |||
| node_id=self._node_id, | |||
| ) | |||
| except ToolInvokeError as e: | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| except PluginInvokeError as e: | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error="An error occurred in the plugin, " | |||
| f"please contact the author of {node_data.provider_name} for help, " | |||
| f"error type: {e.get_error_type()}, " | |||
| f"error details: {e.get_error_message()}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, | |||
| error=f"Failed to invoke tool, error: {e.description}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| def _generate_parameters( | |||
| self, | |||
| *, | |||
| tool_parameters: Sequence[ToolParameter], | |||
| variable_pool: "VariablePool", | |||
| node_data: ToolNodeData, | |||
| for_log: bool = False, | |||
| ) -> dict[str, Any]: | |||
| """ | |||
| Generate parameters based on the given tool parameters, variable pool, and node data. | |||
| Args: | |||
| tool_parameters (Sequence[ToolParameter]): The list of tool parameters. | |||
| variable_pool (VariablePool): The variable pool containing the variables. | |||
| node_data (ToolNodeData): The data associated with the tool node. | |||
| Returns: | |||
| Mapping[str, Any]: A dictionary containing the generated parameters. | |||
| """ | |||
| tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} | |||
| result: dict[str, Any] = {} | |||
| for parameter_name in node_data.tool_parameters: | |||
| parameter = tool_parameters_dictionary.get(parameter_name) | |||
| if not parameter: | |||
| result[parameter_name] = None | |||
| continue | |||
| tool_input = node_data.tool_parameters[parameter_name] | |||
| if tool_input.type == "variable": | |||
| variable = variable_pool.get(tool_input.value) | |||
| if variable is None: | |||
| if parameter.required: | |||
| raise ToolParameterError(f"Variable {tool_input.value} does not exist") | |||
| continue | |||
| parameter_value = variable.value | |||
| elif tool_input.type in {"mixed", "constant"}: | |||
| segment_group = variable_pool.convert_template(str(tool_input.value)) | |||
| parameter_value = segment_group.log if for_log else segment_group.text | |||
| else: | |||
| raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") | |||
| result[parameter_name] = parameter_value | |||
| return result | |||
| def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: | |||
| variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) | |||
| assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) | |||
| return list(variable.value) if variable else [] | |||
| def _transform_message( | |||
| self, | |||
| messages: Generator[ToolInvokeMessage, None, None], | |||
| tool_info: Mapping[str, Any], | |||
| parameters_for_log: dict[str, Any], | |||
| user_id: str, | |||
| tenant_id: str, | |||
| node_id: str, | |||
| ) -> Generator: | |||
| """ | |||
| Convert ToolInvokeMessages into tuple[plain_text, files] | |||
| """ | |||
| # transform message and handle file storage | |||
| from core.plugin.impl.plugin import PluginInstaller | |||
| message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( | |||
| messages=messages, | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| conversation_id=None, | |||
| ) | |||
| text = "" | |||
| files: list[File] = [] | |||
| json: list[dict] = [] | |||
| variables: dict[str, Any] = {} | |||
| for message in message_stream: | |||
| if message.type in { | |||
| ToolInvokeMessage.MessageType.IMAGE_LINK, | |||
| ToolInvokeMessage.MessageType.BINARY_LINK, | |||
| ToolInvokeMessage.MessageType.IMAGE, | |||
| }: | |||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||
| url = message.message.text | |||
| if message.meta: | |||
| transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) | |||
| else: | |||
| transfer_method = FileTransferMethod.TOOL_FILE | |||
| tool_file_id = str(url).split("/")[-1].split(".")[0] | |||
| with Session(db.engine) as session: | |||
| stmt = select(ToolFile).where(ToolFile.id == tool_file_id) | |||
| tool_file = session.scalar(stmt) | |||
| if tool_file is None: | |||
| raise ToolFileError(f"Tool file {tool_file_id} does not exist") | |||
| mapping = { | |||
| "tool_file_id": tool_file_id, | |||
| "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), | |||
| "transfer_method": transfer_method, | |||
| "url": url, | |||
| } | |||
| file = file_factory.build_from_mapping( | |||
| mapping=mapping, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| files.append(file) | |||
| elif message.type == ToolInvokeMessage.MessageType.BLOB: | |||
| # get tool file id | |||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||
| assert message.meta | |||
| tool_file_id = message.message.text.split("/")[-1].split(".")[0] | |||
| with Session(db.engine) as session: | |||
| stmt = select(ToolFile).where(ToolFile.id == tool_file_id) | |||
| tool_file = session.scalar(stmt) | |||
| if tool_file is None: | |||
| raise ToolFileError(f"tool file {tool_file_id} not exists") | |||
| mapping = { | |||
| "tool_file_id": tool_file_id, | |||
| "transfer_method": FileTransferMethod.TOOL_FILE, | |||
| } | |||
| files.append( | |||
| file_factory.build_from_mapping( | |||
| mapping=mapping, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.TEXT: | |||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||
| text += message.message.text | |||
| yield StreamChunkEvent( | |||
| selector=[node_id, "text"], | |||
| chunk=message.message.text, | |||
| is_final=False, | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.JSON: | |||
| assert isinstance(message.message, ToolInvokeMessage.JsonMessage) | |||
| # JSON message handling for tool node | |||
| if message.message.json_object is not None: | |||
| json.append(message.message.json_object) | |||
| elif message.type == ToolInvokeMessage.MessageType.LINK: | |||
| assert isinstance(message.message, ToolInvokeMessage.TextMessage) | |||
| stream_text = f"Link: {message.message.text}\n" | |||
| text += stream_text | |||
| yield StreamChunkEvent( | |||
| selector=[node_id, "text"], | |||
| chunk=stream_text, | |||
| is_final=False, | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.VARIABLE: | |||
| assert isinstance(message.message, ToolInvokeMessage.VariableMessage) | |||
| variable_name = message.message.variable_name | |||
| variable_value = message.message.variable_value | |||
| if message.message.stream: | |||
| if not isinstance(variable_value, str): | |||
| raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") | |||
| if variable_name not in variables: | |||
| variables[variable_name] = "" | |||
| variables[variable_name] += variable_value | |||
| yield StreamChunkEvent( | |||
| selector=[node_id, variable_name], | |||
| chunk=variable_value, | |||
| is_final=False, | |||
| ) | |||
| else: | |||
| variables[variable_name] = variable_value | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE: | |||
| assert message.meta is not None | |||
| assert isinstance(message.meta, dict) | |||
| # Validate that meta contains a 'file' key | |||
| if "file" not in message.meta: | |||
| raise ToolNodeError("File message is missing 'file' key in meta") | |||
| # Validate that the file is an instance of File | |||
| if not isinstance(message.meta["file"], File): | |||
| raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") | |||
| files.append(message.meta["file"]) | |||
| elif message.type == ToolInvokeMessage.MessageType.LOG: | |||
| assert isinstance(message.message, ToolInvokeMessage.LogMessage) | |||
| if message.message.metadata: | |||
| icon = tool_info.get("icon", "") | |||
| dict_metadata = dict(message.message.metadata) | |||
| if dict_metadata.get("provider"): | |||
| manager = PluginInstaller() | |||
| plugins = manager.list_plugins(tenant_id) | |||
| try: | |||
| current_plugin = next( | |||
| plugin | |||
| for plugin in plugins | |||
| if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] | |||
| ) | |||
| icon = current_plugin.declaration.icon | |||
| except StopIteration: | |||
| pass | |||
| icon_dark = None | |||
| try: | |||
| builtin_tool = next( | |||
| provider | |||
| for provider in BuiltinToolManageService.list_builtin_tools( | |||
| user_id, | |||
| tenant_id, | |||
| ) | |||
| if provider.name == dict_metadata["provider"] | |||
| ) | |||
| icon = builtin_tool.icon | |||
| icon_dark = builtin_tool.icon_dark | |||
| except StopIteration: | |||
| pass | |||
| dict_metadata["icon"] = icon | |||
| dict_metadata["icon_dark"] = icon_dark | |||
| message.message.metadata = dict_metadata | |||
| # Add agent_logs to outputs['json'] to ensure frontend can access thinking process | |||
| json_output: list[dict[str, Any]] = [] | |||
| # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] | |||
| if json: | |||
| json_output.extend(json) | |||
| else: | |||
| json_output.append({"data": []}) | |||
| # Send final chunk events for all streamed outputs | |||
| # Final chunk for text stream | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, "text"], | |||
| chunk="", | |||
| is_final=True, | |||
| ) | |||
| # Final chunks for any streamed variables | |||
| for var_name in variables: | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, var_name], | |||
| chunk="", | |||
| is_final=True, | |||
| ) | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, | |||
| metadata={ | |||
| WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, | |||
| }, | |||
| inputs=parameters_for_log, | |||
| ) | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: Mapping[str, Any], | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| # Create typed NodeData from dict | |||
| typed_node_data = ToolNodeData.model_validate(node_data) | |||
| result = {} | |||
| for parameter_name in typed_node_data.tool_parameters: | |||
| input = typed_node_data.tool_parameters[parameter_name] | |||
| if input.type == "mixed": | |||
| assert isinstance(input.value, str) | |||
| selectors = VariableTemplateParser(input.value).extract_variable_selectors() | |||
| for selector in selectors: | |||
| result[selector.variable] = selector.value_selector | |||
| elif input.type == "variable": | |||
| result[parameter_name] = input.value | |||
| elif input.type == "constant": | |||
| pass | |||
| result = {node_id + "." + key: value for key, value in result.items()} | |||
| return result | |||
| def _get_error_strategy(self) -> Optional[ErrorStrategy]: | |||
| return self._node_data.error_strategy | |||
| def _get_retry_config(self) -> RetryConfig: | |||
| return self._node_data.retry_config | |||
| def _get_title(self) -> str: | |||
| return self._node_data.title | |||
| def _get_description(self) -> Optional[str]: | |||
| return self._node_data.desc | |||
| def _get_default_value_dict(self) -> dict[str, Any]: | |||
| return self._node_data.default_value_dict | |||
| def get_base_node_data(self) -> BaseNodeData: | |||
| return self._node_data | |||
| @property | |||
| def retry(self) -> bool: | |||
| return self._node_data.retry_config.retry_enabled | |||
| @@ -1,445 +0,0 @@ | |||
| import logging | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional | |||
| from configs import dify_config | |||
| from core.app.apps.exc import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.models import File | |||
| from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_engine import GraphEngine | |||
| from core.workflow.graph_engine.command_channels import InMemoryChannel | |||
| from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer | |||
| from core.workflow.graph_engine.protocols.command_channel import CommandChannel | |||
| from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from factories import file_factory | |||
| from models.enums import UserFrom | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| class WorkflowEntry: | |||
| def __init__( | |||
| self, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| workflow_id: str, | |||
| graph_config: Mapping[str, Any], | |||
| graph: Graph, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| call_depth: int, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| command_channel: Optional[CommandChannel] = None, | |||
| ) -> None: | |||
| """ | |||
| Init workflow entry | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param workflow_id: workflow id | |||
| :param workflow_type: workflow type | |||
| :param graph_config: workflow graph config | |||
| :param graph: workflow graph | |||
| :param user_id: user id | |||
| :param user_from: user from | |||
| :param invoke_from: invoke from | |||
| :param call_depth: call depth | |||
| :param variable_pool: variable pool | |||
| :param graph_runtime_state: pre-created graph runtime state | |||
| :param command_channel: command channel for external control (optional, defaults to InMemoryChannel) | |||
| :param thread_pool_id: thread pool id | |||
| """ | |||
| # check call depth | |||
| workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH | |||
| if call_depth > workflow_call_max_depth: | |||
| raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.") | |||
| # Use provided command channel or default to InMemoryChannel | |||
| if command_channel is None: | |||
| command_channel = InMemoryChannel() | |||
| self.command_channel = command_channel | |||
| self.graph_engine = GraphEngine( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_id=workflow_id, | |||
| user_id=user_id, | |||
| user_from=user_from, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| graph=graph, | |||
| graph_config=graph_config, | |||
| graph_runtime_state=graph_runtime_state, | |||
| max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, | |||
| max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, | |||
| command_channel=command_channel, | |||
| ) | |||
| # Add debug logging layer when in debug mode | |||
| if dify_config.DEBUG: | |||
| logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine") | |||
| debug_layer = DebugLoggingLayer( | |||
| level="DEBUG", | |||
| include_inputs=True, | |||
| include_outputs=True, | |||
| include_process_data=False, # Process data can be very verbose | |||
| logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger | |||
| ) | |||
| self.graph_engine.layer(debug_layer) | |||
| # Add execution limits layer | |||
| limits_layer = ExecutionLimitsLayer( | |||
| max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME | |||
| ) | |||
| self.graph_engine.layer(limits_layer) | |||
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||
| graph_engine = self.graph_engine | |||
| try: | |||
| # run workflow | |||
| generator = graph_engine.run() | |||
| yield from generator | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when workflow entry running") | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| return | |||
| @classmethod | |||
| def single_step_run( | |||
| cls, | |||
| *, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_id: str, | |||
| user_inputs: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: | |||
| """ | |||
| Single step run workflow node | |||
| :param workflow: Workflow instance | |||
| :param node_id: node id | |||
| :param user_id: user id | |||
| :param user_inputs: user inputs | |||
| :return: | |||
| """ | |||
| node_config = workflow.get_node_config_by_id(node_id) | |||
| node_config_data = node_config.get("data", {}) | |||
| # Get node class | |||
| node_type = NodeType(node_config_data.get("type")) | |||
| node_version = node_config_data.get("version", "1") | |||
| node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] | |||
| # init graph init params and runtime state | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=user_id, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # init node factory | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory) | |||
| # init workflow run state | |||
| node = node_cls( | |||
| id=str(uuid.uuid4()), | |||
| config=node_config, | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| node.init_node_data(node_config_data) | |||
| try: | |||
| # variable selector to variable mapping | |||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||
| graph_config=workflow.graph_dict, config=node_config | |||
| ) | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| # Loading missing variable from draft var here, and set it into | |||
| # variable_pool. | |||
| load_into_variable_pool( | |||
| variable_loader=variable_loader, | |||
| variable_pool=variable_pool, | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| ) | |||
| cls.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=workflow.tenant_id, | |||
| ) | |||
| try: | |||
| # run node | |||
| generator = node.run() | |||
| except Exception as e: | |||
| logger.exception( | |||
| "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", | |||
| workflow.id, | |||
| node.id, | |||
| node.node_type, | |||
| node.version(), | |||
| ) | |||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | |||
| return node, generator | |||
| @staticmethod | |||
| def _create_single_node_graph( | |||
| node_id: str, | |||
| node_data: dict[str, Any], | |||
| node_width: int = 114, | |||
| node_height: int = 514, | |||
| ) -> dict[str, Any]: | |||
| """ | |||
| Create a minimal graph structure for testing a single node in isolation. | |||
| :param node_id: ID of the target node | |||
| :param node_data: configuration data for the target node | |||
| :param node_width: width for UI layout (default: 200) | |||
| :param node_height: height for UI layout (default: 100) | |||
| :return: graph dictionary with start node and target node | |||
| """ | |||
| node_config = { | |||
| "id": node_id, | |||
| "width": node_width, | |||
| "height": node_height, | |||
| "type": "custom", | |||
| "data": node_data, | |||
| } | |||
| start_node_config = { | |||
| "id": "start", | |||
| "width": node_width, | |||
| "height": node_height, | |||
| "type": "custom", | |||
| "data": { | |||
| "type": NodeType.START.value, | |||
| "title": "Start", | |||
| "desc": "Start", | |||
| }, | |||
| } | |||
| return { | |||
| "nodes": [start_node_config, node_config], | |||
| "edges": [ | |||
| { | |||
| "source": "start", | |||
| "target": node_id, | |||
| "sourceHandle": "source", | |||
| "targetHandle": "target", | |||
| } | |||
| ], | |||
| } | |||
| @classmethod | |||
| def run_free_node( | |||
| cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] | |||
| ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: | |||
| """ | |||
| Run free node | |||
| NOTE: only parameter_extractor/question_classifier are supported | |||
| :param node_data: node data | |||
| :param node_id: node id | |||
| :param tenant_id: tenant id | |||
| :param user_id: user id | |||
| :param user_inputs: user inputs | |||
| :return: | |||
| """ | |||
| # Create a minimal graph for single node execution | |||
| graph_dict = cls._create_single_node_graph(node_id, node_data) | |||
| node_type = NodeType(node_data.get("type", "")) | |||
| if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: | |||
| raise ValueError(f"Node type {node_type} not supported") | |||
| node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] | |||
| if not node_cls: | |||
| raise ValueError(f"Node class not found for node type {node_type}") | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| # init graph init params and runtime state | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=tenant_id, | |||
| app_id="", | |||
| workflow_id="", | |||
| graph_config=graph_dict, | |||
| user_id=user_id, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # init node factory | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init(graph_config=graph_dict, node_factory=node_factory) | |||
| node_cls = cast(type[Node], node_cls) | |||
| # init workflow run state | |||
| node_config = { | |||
| "id": node_id, | |||
| "data": node_data, | |||
| } | |||
| node: Node = node_cls( | |||
| id=str(uuid.uuid4()), | |||
| config=node_config, | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| node.init_node_data(node_data) | |||
| try: | |||
| # variable selector to variable mapping | |||
| try: | |||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||
| graph_config=graph_dict, config=node_config | |||
| ) | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| cls.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| # run node | |||
| generator = node.run() | |||
| return node, generator | |||
| except Exception as e: | |||
| logger.exception( | |||
| "error while running node, node_id=%s, node_type=%s, node_version=%s", | |||
| node.id, | |||
| node.node_type, | |||
| node.version(), | |||
| ) | |||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | |||
| @staticmethod | |||
| def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: | |||
| # NOTE(QuantumGhost): Avoid using this function in new code. | |||
| # Keep values structured as long as possible and only convert to dict | |||
| # immediately before serialization (e.g., JSON serialization) to maintain | |||
| # data integrity and type information. | |||
| result = WorkflowEntry._handle_special_values(value) | |||
| return result if isinstance(result, Mapping) or result is None else dict(result) | |||
| @staticmethod | |||
| def _handle_special_values(value: Any) -> Any: | |||
| if value is None: | |||
| return value | |||
| if isinstance(value, dict): | |||
| res = {} | |||
| for k, v in value.items(): | |||
| res[k] = WorkflowEntry._handle_special_values(v) | |||
| return res | |||
| if isinstance(value, list): | |||
| res_list = [] | |||
| for item in value: | |||
| res_list.append(WorkflowEntry._handle_special_values(item)) | |||
| return res_list | |||
| if isinstance(value, File): | |||
| return value.to_dict() | |||
| return value | |||
| @classmethod | |||
| def mapping_user_inputs_to_variable_pool( | |||
| cls, | |||
| *, | |||
| variable_mapping: Mapping[str, Sequence[str]], | |||
| user_inputs: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| tenant_id: str, | |||
| ) -> None: | |||
| # NOTE(QuantumGhost): This logic should remain synchronized with | |||
| # the implementation of `load_into_variable_pool`, specifically the logic about | |||
| # variable existence checking. | |||
| # WARNING(QuantumGhost): The semantics of this method are not clearly defined, | |||
| # and multiple parts of the codebase depend on its current behavior. | |||
| # Modify with caution. | |||
| for node_variable, variable_selector in variable_mapping.items(): | |||
| # fetch node id and variable key from node_variable | |||
| node_variable_list = node_variable.split(".") | |||
| if len(node_variable_list) < 1: | |||
| raise ValueError(f"Invalid node variable {node_variable}") | |||
| node_variable_key = ".".join(node_variable_list[1:]) | |||
| if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( | |||
| variable_selector | |||
| ): | |||
| raise ValueError(f"Variable key {node_variable} not found in user inputs.") | |||
| # environment variable already exist in variable pool, not from user inputs | |||
| if variable_pool.get(variable_selector): | |||
| continue | |||
| # fetch variable node id from variable selector | |||
| variable_node_id = variable_selector[0] | |||
| variable_key_list = variable_selector[1:] | |||
| variable_key_list = list(variable_key_list) | |||
| # get input value | |||
| input_value = user_inputs.get(node_variable) | |||
| if not input_value: | |||
| input_value = user_inputs.get(node_variable_key) | |||
| if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: | |||
| input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) | |||
| if ( | |||
| isinstance(input_value, list) | |||
| and all(isinstance(item, dict) for item in input_value) | |||
| and all("type" in item and "transfer_method" in item for item in input_value) | |||
| ): | |||
| input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) | |||
| # append variable and value to variable pool | |||
| if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: | |||
| variable_pool.add([variable_node_id] + variable_key_list, input_value) | |||
| @@ -1,390 +0,0 @@ | |||
| import time | |||
| import uuid | |||
| from os import getenv | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.enums import WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.node_events import NodeRunResult | |||
| from core.workflow.nodes.code.code_node import CodeNode | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||
| CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) | |||
| def init_code_node(code_config: dict): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-code-target", | |||
| "source": "start", | |||
| "target": "code", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], | |||
| } | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable(user_id="aaa", files=[]), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["code", "args1"], 1) | |||
| variable_pool.add(["code", "args2"], 2) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # Create node factory | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| node = CodeNode( | |||
| id=str(uuid.uuid4()), | |||
| config=code_config, | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # Initialize node data | |||
| if "data" in code_config: | |||
| node.init_node_data(code_config["data"]) | |||
| return node | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| def test_execute_code(setup_code_executor_mock): | |||
| code = """ | |||
| def main(args1: int, args2: int) -> dict: | |||
| return { | |||
| "result": args1 + args2, | |||
| } | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| } | |||
| node = init_code_node(code_config) | |||
| node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) | |||
| node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["result"] == 3 | |||
| assert result.error == "" | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| def test_execute_code_output_validator(setup_code_executor_mock): | |||
| code = """ | |||
| def main(args1: int, args2: int) -> dict: | |||
| return { | |||
| "result": args1 + args2, | |||
| } | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "string", | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| } | |||
| node = init_code_node(code_config) | |||
| node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) | |||
| node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||
| assert result.error == "Output variable `result` must be a string" | |||
| def test_execute_code_output_validator_depth(): | |||
| code = """ | |||
| def main(args1: int, args2: int) -> dict: | |||
| return { | |||
| "result": { | |||
| "result": args1 + args2, | |||
| } | |||
| } | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "string_validator": { | |||
| "type": "string", | |||
| }, | |||
| "number_validator": { | |||
| "type": "number", | |||
| }, | |||
| "number_array_validator": { | |||
| "type": "array[number]", | |||
| }, | |||
| "string_array_validator": { | |||
| "type": "array[string]", | |||
| }, | |||
| "object_validator": { | |||
| "type": "object", | |||
| "children": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "number", | |||
| } | |||
| }, | |||
| } | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| } | |||
| node = init_code_node(code_config) | |||
| # construct result | |||
| result = { | |||
| "number_validator": 1, | |||
| "string_validator": "1", | |||
| "number_array_validator": [1, 2, 3, 3.333], | |||
| "string_array_validator": ["1", "2", "3"], | |||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||
| } | |||
| # validate | |||
| node._transform_result(result, node._node_data.outputs) | |||
| # construct result | |||
| result = { | |||
| "number_validator": "1", | |||
| "string_validator": 1, | |||
| "number_array_validator": ["1", "2", "3", "3.333"], | |||
| "string_array_validator": [1, 2, 3], | |||
| "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, | |||
| } | |||
| # validate | |||
| with pytest.raises(ValueError): | |||
| node._transform_result(result, node._node_data.outputs) | |||
| # construct result | |||
| result = { | |||
| "number_validator": 1, | |||
| "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", | |||
| "number_array_validator": [1, 2, 3, 3.333], | |||
| "string_array_validator": ["1", "2", "3"], | |||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||
| } | |||
| # validate | |||
| with pytest.raises(ValueError): | |||
| node._transform_result(result, node._node_data.outputs) | |||
| # construct result | |||
| result = { | |||
| "number_validator": 1, | |||
| "string_validator": "1", | |||
| "number_array_validator": [1, 2, 3, 3.333] * 2000, | |||
| "string_array_validator": ["1", "2", "3"], | |||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||
| } | |||
| # validate | |||
| with pytest.raises(ValueError): | |||
| node._transform_result(result, node._node_data.outputs) | |||
| def test_execute_code_output_object_list(): | |||
| code = """ | |||
| def main(args1: int, args2: int) -> dict: | |||
| return { | |||
| "result": { | |||
| "result": args1 + args2, | |||
| } | |||
| } | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "object_list": { | |||
| "type": "array[object]", | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| } | |||
| node = init_code_node(code_config) | |||
| # construct result | |||
| result = { | |||
| "object_list": [ | |||
| { | |||
| "result": 1, | |||
| }, | |||
| { | |||
| "result": 2, | |||
| }, | |||
| { | |||
| "result": [1, 2, 3], | |||
| }, | |||
| ] | |||
| } | |||
| # validate | |||
| node._transform_result(result, node._node_data.outputs) | |||
| # construct result | |||
| result = { | |||
| "object_list": [ | |||
| { | |||
| "result": 1, | |||
| }, | |||
| { | |||
| "result": 2, | |||
| }, | |||
| { | |||
| "result": [1, 2, 3], | |||
| }, | |||
| 1, | |||
| ] | |||
| } | |||
| # validate | |||
| with pytest.raises(ValueError): | |||
| node._transform_result(result, node._node_data.outputs) | |||
| def test_execute_code_scientific_notation(): | |||
| code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": -8.0E-5 | |||
| } | |||
| """ | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| } | |||
| node = init_code_node(code_config) | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||