Преглед изворни кода

fix: iteration and loop node single step run (#26036)

tags/1.9.0
Novice пре 1 месец
родитељ
комит
d823da18db
No account linked to committer's email address

+ 5
- 22
api/core/app/apps/advanced_chat/app_runner.py Прегледај датотеку

if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")


if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

+ 5
- 0
api/core/app/apps/pipeline/pipeline_generator.py Прегледај датотеку

invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
workflow_execution_id=str(uuid.uuid4()), workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
) )
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )


def single_loop_generate( def single_loop_generate(
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )


def _generate_worker( def _generate_worker(

+ 5
- 22
api/core/app/apps/pipeline/pipeline_runner.py Прегледај датотеку

db.session.close() db.session.close()


# if only single iteration run is requested # if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

+ 5
- 23
api/core/app/apps/workflow/app_runner.py Прегледај датотеку

app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)


# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

+ 111
- 120
api/core/app/apps/workflow_app_runner.py Прегледај датотеку

import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast




return graph return graph


def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).

Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise

Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)

Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)

# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")

# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state

def _get_graph_and_variable_pool_for_single_node_run(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).

Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')

Returns:
A tuple containing (graph, variable_pool)
""" """
# fetch workflow graph # fetch workflow graph
graph_config = workflow.graph_dict graph_config = workflow.graph_dict
if not isinstance(graph_config.get("edges"), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list") raise ValueError("edges in workflow graph must be a list")


# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [ node_configs = [
node node
for node in graph_config.get("nodes", []) for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
] ]


graph_config["nodes"] = node_configs graph_config["nodes"] = node_configs


node_ids = [node.get("id") for node in node_configs] node_ids = [node.get("id") for node in node_configs]


# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [ edge_configs = [
edge edge
for edge in graph_config.get("edges", []) for edge in graph_config.get("edges", [])
raise ValueError("graph not found in workflow") raise ValueError("graph not found in workflow")


# fetch node config from node id # fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs: for node in node_configs:
if node.get("id") == node_id: if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break break


if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")


# Get node class # Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]


# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool


try: try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}


return graph, variable_pool return graph, variable_pool


def _get_graph_and_variable_pool_of_single_loop(
def _get_graph_and_variable_pool_of_single_iteration(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single loop
Get variable pool of single iteration
""" """
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")

graph_config = cast(dict[str, Any], graph_config)

if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")

if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")

if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")

# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]

graph_config["nodes"] = node_configs

node_ids = [node.get("id") for node in node_configs]

# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]

graph_config["edges"] = edge_configs

# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)

node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)

if not graph:
raise ValueError("graph not found in workflow")

# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break

if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")

# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
) )


WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
) )


return graph, variable_pool

def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
""" """
Handle event Handle event

+ 10
- 39
api/core/workflow/nodes/iteration/iteration_node.py Прегледај датотеку

variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,
} }
iteration_node_ids = set()


# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory

# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)

# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool

graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)

# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)

iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)

if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)


# Get node configs from graph_config instead of non-existent node_id_config_mapping # Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
variable_mapping.update(sub_node_variable_mapping) variable_mapping.update(sub_node_variable_mapping)


# remove variable out from iteration # remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}


return variable_mapping return variable_mapping



+ 37
- 42
api/core/workflow/nodes/loop/loop_node.py Прегледај датотеку

import contextlib
import json import json
import logging import logging
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
try: try:
reach_break_condition = False reach_break_condition = False
if break_conditions: if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
with contextlib.suppress(ValueError):
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)

if reach_break_condition: if reach_break_condition:
loop_count = 0 loop_count = 0
cost_tokens = 0 cost_tokens = 0


variable_mapping = {} variable_mapping = {}


# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory

# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)

# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)

# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
# Extract loop node IDs statically from graph_config


loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)

if not loop_graph:
raise ValueError("loop graph not found")
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)


# Get node configs from graph_config instead of non-existent node_id_config_mapping
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items(): for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id: if sub_node_config.get("data", {}).get("loop_id") != node_id:
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector variable_mapping[f"{node_id}.{loop_variable.label}"] = selector


# remove variable out from loop # remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}


return variable_mapping return variable_mapping


@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.

This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.

:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()

# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)

return loop_node_ids

@staticmethod @staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value.""" """Get the appropriate segment type for a constant value."""

Loading…
Откажи
Сачувај