ソースを参照

Feat: continue on error (#11458)

Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
tags/0.14.0
Novice 10ヶ月前
コミット
79a710ce98
コミッターのメールアドレスに関連付けられたアカウントが存在しません
31個のファイルの変更1211行の追加80行の削除
  1. 27
    1
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  2. 3
    1
      api/core/app/apps/workflow/app_queue_manager.py
  3. 61
    7
      api/core/app/apps/workflow/generate_task_pipeline.py
  4. 39
    1
      api/core/app/apps/workflow_app_runner.py
  5. 44
    0
      api/core/app/entities/queue_entities.py
  6. 1
    0
      api/core/app/entities/task_entities.py
  7. 68
    6
      api/core/app/task_pipeline/workflow_cycle_manage.py
  8. 7
    1
      api/core/helper/ssrf_proxy.py
  9. 3
    0
      api/core/workflow/callbacks/workflow_logging_callback.py
  10. 2
    0
      api/core/workflow/entities/node_entities.py
  11. 10
    0
      api/core/workflow/graph_engine/entities/event.py
  12. 18
    9
      api/core/workflow/graph_engine/entities/graph.py
  13. 9
    1
      api/core/workflow/graph_engine/entities/runtime_route_state.py
  14. 146
    26
      api/core/workflow/graph_engine/graph_engine.py
  15. 13
    8
      api/core/workflow/nodes/answer/answer_stream_generate_router.py
  16. 2
    1
      api/core/workflow/nodes/answer/answer_stream_processor.py
  17. 112
    2
      api/core/workflow/nodes/base/entities.py
  18. 10
    0
      api/core/workflow/nodes/base/exc.py
  19. 11
    5
      api/core/workflow/nodes/base/node.py
  20. 3
    1
      api/core/workflow/nodes/code/code_node.py
  21. 13
    0
      api/core/workflow/nodes/enums.py
  22. 5
    2
      api/core/workflow/nodes/http_request/executor.py
  23. 16
    0
      api/core/workflow/nodes/http_request/node.py
  24. 1
    0
      api/core/workflow/nodes/llm/node.py
  25. 1
    1
      api/core/workflow/nodes/question_classifier/question_classifier_node.py
  26. 2
    0
      api/core/workflow/nodes/tool/tool_node.py
  27. 4
    0
      api/fields/workflow_run_fields.py
  28. 33
    0
      api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py
  29. 6
    1
      api/models/workflow.py
  30. 39
    6
      api/services/workflow_service.py
  31. 502
    0
      api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

+ 27
- 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py ファイルの表示

@@ -19,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -31,6 +32,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -317,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
@@ -384,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@@ -401,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)

yield self._workflow_finish_to_stream_response(

+ 3
- 1
api/core/app/apps/workflow/app_queue_manager.py ファイルの表示

@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()


+ 61
- 7
api/core/app/apps/workflow/generate_task_pipeline.py ファイルの表示

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -276,7 +278,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
@@ -345,22 +347,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_failed(
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
@@ -368,6 +368,60 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
}
workflow_run = self._handle_workflow_run_failed(**handle_args)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count,
}
workflow_run = self._handle_workflow_run_partial_success(**handle_args)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

+ 39
- 1
api/core/app/apps/workflow_app_runner.py ファイルの表示

@@ -8,6 +8,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -18,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(

+ 44
- 0
api/core/app/entities/queue_entities.py ファイルの表示

@@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int


class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""

event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None


class QueueNodeStartedEvent(AppQueueEvent):
@@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str


class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""

event: QueueEvent = QueueEvent.NODE_EXCEPTION

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity

+ 1
- 0
api/core/app/entities/task_entities.py ファイルの表示

@@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []

event: StreamEvent = StreamEvent.WORKFLOW_FINISHED

+ 68
- 6
api/core/app/task_pipeline/workflow_cycle_manage.py ファイルの表示

@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -164,6 +165,55 @@ class WorkflowCycleManage:

return workflow_run

def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)

outputs = WorkflowEntry.handle_special_values(outputs)

workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)

if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)

db.session.close()

return workflow_run

def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@@ -174,6 +224,7 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@@ -193,7 +244,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()

running_workflow_node_executions = (
@@ -318,7 +369,7 @@ class WorkflowCycleManage:
return workflow_node_execution

def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -337,7 +388,11 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -351,8 +406,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)

workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -433,6 +491,7 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)

@@ -483,7 +542,10 @@ class WorkflowCycleManage:

def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:

+ 7
- 1
api/core/helper/ssrf_proxy.py ファイルの表示

@@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""

pass


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

+ 3
- 0
api/core/workflow/callbacks/workflow_logging_callback.py ファイルの表示

@@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):

+ 2
- 0
api/core/workflow/entities/node_entities.py ファイルの表示

@@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field


class NodeRunResult(BaseModel):
@@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed

+ 10
- 0
api/core/workflow/graph_engine/entities/event.py ファイルの表示

@@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):

class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)


class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None


###########################################
@@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")


class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")


class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")


+ 18
- 9
api/core/workflow/graph_engine/entities/graph.py ファイルの表示

@@ -64,13 +64,21 @@ class Graph(BaseModel):
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")

edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)

# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
@@ -90,8 +98,16 @@ class Graph(BaseModel):

# parse run condition
run_condition = None
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)

graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
@@ -100,13 +116,6 @@ class Graph(BaseModel):
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)

# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")

node_configs = cast(list, node_configs)

# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}

+ 9
- 1
api/core/workflow/graph_engine/entities/runtime_route_state.py ファイルの表示

@@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
@@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):

:param run_result: run result
"""
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")

if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")


+ 146
- 26
api/core/workflow/graph_engine/graph_engine.py ファイルの表示

@@ -5,21 +5,23 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from typing import Any, Optional, cast

from flask import Flask, current_app

from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
@@ -128,6 +132,7 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []

try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@@ -140,13 +145,17 @@ class GraphEngine:
)

# run graph
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))

generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
exceptions_count=len(handle_exceptions),
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
@@ -172,19 +181,24 @@ class GraphEngine:
].strip()
except Exception as e:
logger.exception("Graph run failed")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
return

# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
# count exceptions to determine partial success
if len(handle_exceptions) > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
)
else:
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
self._release_thread()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
self._release_thread()
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e

@@ -198,6 +212,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
@@ -242,7 +257,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
try:
# run node
generator = self._run_node(
@@ -252,6 +267,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in generator:
@@ -301,7 +317,12 @@ class GraphEngine:

if len(edge_mappings) == 1:
edge = edge_mappings[0]

if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -334,7 +355,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue

edge = sub_edge_mappings[0]
edge = cast(GraphEdge, sub_edge_mappings[0])

result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -355,6 +376,7 @@ class GraphEngine:
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in parallel_generator:
@@ -369,11 +391,18 @@ class GraphEngine:
break

next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in parallel_generator:
@@ -395,6 +424,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@@ -438,6 +468,7 @@ class GraphEngine:
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)

@@ -481,6 +512,7 @@ class GraphEngine:
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
"""
Run parallel nodes
@@ -502,6 +534,7 @@ class GraphEngine:
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in generator:
@@ -548,6 +581,7 @@ class GraphEngine:
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
@@ -587,19 +621,55 @@ class GraphEngine:
route_node_state.set_finished(run_result=run_result)

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)

elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
@@ -735,6 +805,56 @@ class GraphEngine:
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance

def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True


:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}

if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
outputs={
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
return error_result


class GraphRunFailedError(Exception):
def __init__(self, error: str):

+ 13
- 8
api/core/workflow/nodes/answer/answer_stream_generate_router.py ファイルの表示

@@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser


@@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}:
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(

+ 2
- 1
api/core/workflow/nodes/answer/answer_stream_processor.py ファイルの表示

@@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):

for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished

+ 112
- 2
api/core/workflow/nodes/base/entities.py ファイルの表示

@@ -1,14 +1,124 @@
import json
from abc import ABC
from typing import Optional
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, model_validator

from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy


class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"


NumberType = Union[int, float]


class DefaultValue(BaseModel):
value: Any
type: DefaultValueType
key: str

@staticmethod
def _parse_json(value: str) -> Any:
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")

@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")

# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}

validator = type_validators.get(self.type)
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")

# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)

# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")

# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")

return self


class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"

@property
def default_value_dict(self):
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}


class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None

+ 10
- 0
api/core/workflow/nodes/base/exc.py ファイルの表示

@@ -0,0 +1,10 @@
class BaseNodeError(Exception):
"""Base class for node errors."""

pass


class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""

pass

+ 11
- 5
api/core/workflow/nodes/base/node.py ファイルの表示

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus

@@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")

if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
:return:
"""
return self._node_type

@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error

Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

+ 3
- 1
api/core/workflow/nodes/code/code_node.py ファイルの表示

@@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]):
# Transform result
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)

return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)


+ 13
- 0
api/core/workflow/nodes/enums.py ファイルの表示

@@ -22,3 +22,16 @@ class NodeType(StrEnum):
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"


class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"


class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"


CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]

+ 5
- 2
api/core/workflow/nodes/http_request/executor.py ファイルの表示

@@ -21,6 +21,7 @@ from .entities import (
from .exc import (
AuthorizationConfigError,
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
ResponseSizeError,
)
@@ -208,8 +209,10 @@ class Executor:
"follow_redirects": True,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}

response = getattr(ssrf_proxy, self.method)(**request_args)
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
raise HttpRequestNodeError(str(e))
return response

def invoke(self) -> Response:

+ 16
- 0
api/core/workflow/nodes/http_request/node.py ファイルの表示

@@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_log(),
},
error=f"Request failed with status code {response.status_code}",
error_type="HTTPResponseCodeError",
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
@@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
error_type=type(e).__name__,
)

@staticmethod

+ 1
- 0
api/core/workflow/nodes/llm/node.py ファイルの表示

@@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
)
)
return

+ 1
- 1
api/core/workflow/nodes/question_classifier/question_classifier_node.py ファイルの表示

@@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode):
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
outputs = {"class_name": category_name}
outputs = {"class_name": category_name, "class_id": category_id}

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

+ 2
- 0
api/core/workflow/nodes/tool/tool_node.py ファイルの表示

@@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)

# get parameters
@@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)

# convert tool messages

+ 4
- 0
api/fields/workflow_run_fields.py ファイルの表示

@@ -14,6 +14,7 @@ workflow_run_for_log_fields = {
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}

workflow_run_for_list_fields = {
@@ -27,6 +28,7 @@ workflow_run_for_list_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}

advanced_chat_workflow_run_for_list_fields = {
@@ -42,6 +44,7 @@ advanced_chat_workflow_run_for_list_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}

advanced_chat_workflow_run_pagination_fields = {
@@ -73,6 +76,7 @@ workflow_run_detail_fields = {
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}

workflow_run_node_execution_fields = {

+ 33
- 0
api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py ファイルの表示

@@ -0,0 +1,33 @@
"""add exceptions_count field to WorkflowRun model

Revision ID: cf8f4fc45278
Revises: 01d6889832f7
Create Date: 2024-11-28 05:53:21.576178

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'cf8f4fc45278'
down_revision = '01d6889832f7'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_column('exceptions_count')

# ### end Alembic commands ###

+ 6
- 1
api/models/workflow.py ファイルの表示

@@ -325,6 +325,7 @@ class WorkflowRunStatus(StrEnum):
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCESSED = "partial-succeeded"

@classmethod
def value_of(cls, value: str) -> "WorkflowRunStatus":
@@ -395,7 +396,7 @@ class WorkflowRun(db.Model):
version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text)
inputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
@@ -405,6 +406,7 @@ class WorkflowRun(db.Model):
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))

@property
def created_by_account(self):
@@ -464,6 +466,7 @@ class WorkflowRun(db.Model):
"created_by": self.created_by,
"created_at": self.created_at,
"finished_at": self.finished_at,
"exceptions_count": self.exceptions_count,
}

@classmethod
@@ -489,6 +492,7 @@ class WorkflowRun(db.Model):
created_by=data.get("created_by"),
created_at=data.get("created_at"),
finished_at=data.get("finished_at"),
exceptions_count=data.get("exceptions_count"),
)


@@ -522,6 +526,7 @@ class WorkflowNodeExecutionStatus(Enum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"

@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":

+ 39
- 6
api/services/workflow_service.py ファイルの表示

@@ -2,7 +2,7 @@ import json
import time
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional
from typing import Optional, cast

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -11,6 +11,9 @@ from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
@@ -225,7 +228,7 @@ class WorkflowService:
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
@@ -237,8 +240,35 @@ class WorkflowService:

if not node_run_result:
raise ValueError("Node run failed with no run result")

run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
@@ -260,7 +290,6 @@ class WorkflowService:
workflow_node_execution.created_by = account.id
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

if run_succeeded and node_run_result:
# create workflow node execution
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
@@ -277,7 +306,11 @@ class WorkflowService:
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
workflow_node_execution.error = node_run_result.error
else:
# create workflow node execution
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value

+ 502
- 0
api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py ファイルの表示

@@ -0,0 +1,502 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.enums import UserFrom
from models.workflow import WorkflowType


class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
"data": {
"outputs": {"result": {"type": "number"}},
"error_strategy": error_strategy,
"title": "code",
"variables": [],
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
):
"""Helper method to create a http node configuration"""
authorization = (
{
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
}
if authorization_success
else {
"type": "api-key",
# missing config field
}
)
node = {
"id": "node",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": authorization,
"headers": "X-Header:123",
"params": "A:b",
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a http node configuration"""
node = {
"id": "node",
"data": {
"type": "http-request",
"title": "HTTP Request",
"desc": "",
"variables": [],
"method": "get",
"url": "https://api.github.com/issues",
"authorization": {"type": "no-auth", "config": None},
"headers": "",
"params": "",
"body": {"type": "none", "data": []},
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a tool node configuration"""
node = {
"id": "node",
"data": {
"title": "a",
"desc": "a",
"provider_id": "maths",
"provider_type": "builtin",
"provider_name": "maths",
"tool_name": "eval_expression",
"tool_label": "eval_expression",
"tool_configurations": {},
"tool_parameters": {
"expression": {
"type": "variable",
"value": ["1", "123", "args1"],
}
},
"type": "tool",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a llm node configuration"""
node = {
"id": "node",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = {
"system_variables": {
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}

return GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)


DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]

FAIL_BRANCH_EDGES = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-true-success-target",
"source": "node",
"target": "success",
"sourceHandle": "source",
},
{
"id": "node-false-error-target",
"source": "node",
"target": "error",
"sourceHandle": "fail-branch",
},
]


def test_code_default_value_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""

graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_code_fail_branch_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""

graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_code_node(error_code),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
)


def test_http_node_default_value_continue_on_error():
"""Test HTTP node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_tool_node_default_value_continue_on_error():
"""Test tool node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_tool_node(
"default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_tool_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_tool_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_llm_node_default_value_continue_on_error():
"""Test LLM node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_llm_node(
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_llm_node_fail_branch_continue_on_error():
"""Test LLM node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_status_code_error_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


def test_variable_pool_error_type_variable():
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
list(graph_engine.run())
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
assert error_message != None
assert error_type.value == "HTTPResponseCodeError"


def test_no_node_in_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0

読み込み中…
キャンセル
保存