Quellcode durchsuchen

fix: iteration and loop node single step run (#26036)

tags/1.9.0
Novice vor 1 Monat
Ursprung
Commit
d823da18db
Es ist kein Account mit der E-Mail-Adresse des Committers verbunden

+ 5
- 22
api/core/app/apps/advanced_chat/app_runner.py Datei anzeigen

@@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")

if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

+ 5
- 0
api/core/app/apps/pipeline/pipeline_generator.py Datei anzeigen

@@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)

def single_loop_generate(
@@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)

def _generate_worker(

+ 5
- 22
api/core/app/apps/pipeline/pipeline_runner.py Datei anzeigen

@@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close()

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

+ 5
- 23
api/core/app/apps/workflow/app_runner.py Datei anzeigen

@@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

+ 111
- 120
api/core/app/apps/workflow_app_runner.py Datei anzeigen

@@ -1,3 +1,4 @@
import time
from collections.abc import Mapping
from typing import Any, cast

@@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:

return graph

def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
    self,
    workflow: Workflow,
    single_iteration_run: Any | None = None,
    single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
    """
    Build everything needed to run one iteration node or one loop node in isolation.

    A fresh GraphRuntimeState is created first; its variable pool carries empty
    system variables, no user inputs, and the workflow's environment variables.
    That same state object is handed to the graph builder and returned, so the
    caller and the nodes created for the graph share one runtime state.

    Args:
        workflow: The workflow whose graph contains the target node
        single_iteration_run: Entity exposing ``node_id`` and ``inputs`` for a single
            iteration run, or None. Takes precedence when both entities are given.
        single_loop_run: Entity exposing ``node_id`` and ``inputs`` for a single
            loop run, or None

    Returns:
        A tuple containing (graph, variable_pool, graph_runtime_state)

    Raises:
        ValueError: If neither single_iteration_run nor single_loop_run is specified
    """
    # The runtime state is built before the mode check, exactly like the original flow.
    runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(
            system_variables=SystemVariable.empty(),
            user_inputs={},
            environment_variables=workflow.environment_variables,
        ),
        start_at=time.time(),
    )

    # Pick the builder and the entity describing the run; iteration wins over loop.
    if single_iteration_run:
        build_graph = self._get_graph_and_variable_pool_of_single_iteration
        run_entity = single_iteration_run
    elif single_loop_run:
        build_graph = self._get_graph_and_variable_pool_of_single_loop
        run_entity = single_loop_run
    else:
        raise ValueError("Neither single_iteration_run nor single_loop_run is specified")

    graph, variable_pool = build_graph(
        workflow=workflow,
        node_id=run_entity.node_id,
        user_inputs=dict(run_entity.inputs),  # copy into a plain dict for the builder
        graph_runtime_state=runtime_state,
    )

    # Hand back the very state instance that was used while the graph was created.
    return graph, variable_pool, runtime_state

def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).

Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')

Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")

# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]

graph_config["nodes"] = node_configs

node_ids = [node.get("id") for node in node_configs]

# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
@@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow")

# fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break

if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")

# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
@@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:

return graph, variable_pool

def _get_graph_and_variable_pool_of_single_loop(
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
Get graph and variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")

graph_config = cast(dict[str, Any], graph_config)

if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")

if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")

if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")

# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]

graph_config["nodes"] = node_configs

node_ids = [node.get("id") for node in node_configs]

# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]

graph_config["edges"] = edge_configs

# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)

node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)

if not graph:
raise ValueError("graph not found in workflow")

# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break

if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")

# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool of single loop
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)

return graph, variable_pool

def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event

+ 10
- 39
api/core/workflow/nodes/iteration/iteration_node.py Datei anzeigen

@@ -372,43 +372,16 @@ class IterationNode(Node):
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
iteration_node_ids = set()

# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory

# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)

# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool

graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)

# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)

iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)

if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Find all nodes that belong to this iteration
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)

# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
@@ -444,9 +417,7 @@ class IterationNode(Node):
variable_mapping.update(sub_node_variable_mapping)

# remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}

return variable_mapping


+ 37
- 42
api/core/workflow/nodes/loop/loop_node.py Datei anzeigen

@@ -1,3 +1,4 @@
import contextlib
import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
@@ -127,11 +128,13 @@ class LoopNode(Node):
try:
reach_break_condition = False
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
with contextlib.suppress(ValueError):
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)

if reach_break_condition:
loop_count = 0
cost_tokens = 0
@@ -295,42 +298,11 @@ class LoopNode(Node):

variable_mapping = {}

# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory

# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)

# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)

# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
# Extract loop node IDs statically from graph_config

loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)

if not loop_graph:
raise ValueError("loop graph not found")
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)

# Get node configs from graph_config instead of non-existent node_id_config_mapping
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
@@ -371,12 +343,35 @@ class LoopNode(Node):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector

# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}

return variable_mapping

@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
    """
    Collect the IDs of every node that is declared as a member of the given loop.

    Only the raw graph configuration is inspected: a node counts as a member when
    its ``data.loop_id`` equals ``loop_node_id``. No node instances are created.
    Nodes without an ``id`` (or with an empty one) are ignored.

    :param graph_config: the complete graph configuration
    :param loop_node_id: the ID of the loop node
    :return: set of node IDs that belong to the loop
    """
    # Lazily pull the "id" of each node whose data points back at this loop.
    member_ids = (
        entry.get("id")
        for entry in graph_config.get("nodes", [])
        if entry.get("data", {}).get("loop_id") == loop_node_id
    )

    # Drop missing/empty IDs while building the result set.
    return {member_id for member_id in member_ids if member_id}

@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""

Laden…
Abbrechen
Speichern