Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.1
| @@ -26,10 +26,13 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.message import MessageNotExistsError | |||
| @@ -159,8 +162,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING | |||
| else: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=workflow_triggered_from, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| @@ -173,6 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user=user, | |||
| invoke_from=invoke_from, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| conversation=conversation, | |||
| stream=streaming, | |||
| @@ -226,8 +244,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| @@ -240,6 +268,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| conversation=None, | |||
| stream=streaming, | |||
| @@ -291,8 +320,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| @@ -305,6 +344,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| conversation=None, | |||
| stream=streaming, | |||
| @@ -317,6 +357,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| invoke_from: InvokeFrom, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| conversation: Optional[Conversation] = None, | |||
| stream: bool = True, | |||
| @@ -381,6 +422,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| user=user, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=stream, | |||
| ) | |||
| @@ -453,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation: Conversation, | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| stream: bool = False, | |||
| ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: | |||
| @@ -476,9 +519,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| user=user, | |||
| stream=stream, | |||
| dialogue_count=self._dialogue_count, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=stream, | |||
| ) | |||
| try: | |||
| @@ -64,6 +64,7 @@ from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from events.message_event import message_was_created | |||
| @@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| dialogue_count: int, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| @@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| @@ -294,21 +297,19 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| created_by_role=self._created_by_role, | |||
| ) | |||
| self._workflow_run_id = workflow_run.id | |||
| self._workflow_run_id = workflow_execution.id | |||
| message = self._get_message(session=session) | |||
| if not message: | |||
| raise ValueError(f"Message not found: {self._message_id}") | |||
| message.workflow_run_id = workflow_run.id | |||
| workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| message.workflow_run_id = workflow_execution.id | |||
| workflow_start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_start_resp | |||
| elif isinstance( | |||
| @@ -319,13 +320,10 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( | |||
| workflow_execution_id=self._workflow_run_id, event=event | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( | |||
| node_retry_resp = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -338,20 +336,15 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( | |||
| workflow_execution_id=self._workflow_run_id, event=event | |||
| ) | |||
| node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| session.commit() | |||
| node_start_resp = self._workflow_cycle_manager.workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_start_resp: | |||
| yield node_start_resp | |||
| @@ -359,15 +352,15 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| # Record files if it's an answer node or end node | |||
| if event.node_type in [NodeType.ANSWER, NodeType.END]: | |||
| self._recorded_files.extend( | |||
| self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) | |||
| self._workflow_cycle_manager.fetch_files_from_node_outputs(event.outputs or {}) | |||
| ) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( | |||
| event=event | |||
| ) | |||
| node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -383,11 +376,11 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| ): | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( | |||
| event=event | |||
| ) | |||
| node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -399,132 +392,90 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_start_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_start_resp | |||
| elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_finish_resp | |||
| elif isinstance(event, QueueIterationStartEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_start_resp | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_next_resp | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_finish_resp | |||
| elif isinstance(event, QueueLoopStartEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_start_resp | |||
| elif isinstance(event, QueueLoopNextEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_next_resp | |||
| elif isinstance(event, QueueLoopCompletedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_finish_resp | |||
| elif isinstance(event, QueueWorkflowSucceededEvent): | |||
| @@ -535,10 +486,8 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| @@ -546,10 +495,11 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| @@ -562,10 +512,8 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| @@ -573,10 +521,11 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| @@ -589,26 +538,25 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error=event.error, | |||
| error_message=event.error, | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| ) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) | |||
| err = self._base_task_pipeline._handle_error( | |||
| event=err_event, session=session, message_id=self._message_id | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| @@ -616,21 +564,19 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| elif isinstance(event, QueueStopEvent): | |||
| if self._workflow_run_id and graph_runtime_state: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| error=event.get_stop_reason(), | |||
| error_message=event.get_stop_reason(), | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| # Save message | |||
| self._save_message(session=session, graph_runtime_state=graph_runtime_state) | |||
| @@ -711,7 +657,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| yield self._message_end_to_stream_response() | |||
| elif isinstance(event, QueueAgentLogEvent): | |||
| yield self._workflow_cycle_manager._handle_agent_log( | |||
| yield self._workflow_cycle_manager.handle_agent_log( | |||
| task_id=self._application_generate_entity.task_id, event=event | |||
| ) | |||
| else: | |||
| @@ -18,16 +18,19 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager | |||
| from core.app.apps.workflow.app_runner import WorkflowAppRunner | |||
| from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter | |||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| @@ -136,9 +139,22 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING | |||
| else: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=workflow_triggered_from, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| @@ -152,6 +168,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| user=user, | |||
| application_generate_entity=application_generate_entity, | |||
| invoke_from=invoke_from, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| @@ -165,6 +182,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| invoke_from: InvokeFrom, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| @@ -209,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=streaming, | |||
| ) | |||
| @@ -262,6 +281,17 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| @@ -278,6 +308,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| ) | |||
| @@ -327,6 +358,17 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create repositories | |||
| # | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| @@ -343,6 +385,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| ) | |||
| @@ -400,6 +443,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| stream: bool = False, | |||
| ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| @@ -419,8 +463,9 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| stream=stream, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=stream, | |||
| ) | |||
| try: | |||
| @@ -0,0 +1,591 @@ | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAgentLogEvent, | |||
| QueueErrorEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueLoopCompletedEvent, | |||
| QueueLoopNextEvent, | |||
| QueueLoopStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueuePingEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| ErrorStreamResponse, | |||
| MessageAudioEndStreamResponse, | |||
| MessageAudioStreamResponse, | |||
| StreamResponse, | |||
| TextChunkStreamResponse, | |||
| WorkflowAppBlockingResponse, | |||
| WorkflowAppStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.enums import CreatorUserRole | |||
| from models.model import EndUser | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class WorkflowAppGenerateTaskPipeline: | |||
| """ | |||
| WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| ) | |||
| if isinstance(user, EndUser): | |||
| self._user_id = user.id | |||
| user_session_id = user.session_id | |||
| self._created_by_role = CreatorUserRole.END_USER | |||
| elif isinstance(user, Account): | |||
| self._user_id = user.id | |||
| user_session_id = user.id | |||
| self._created_by_role = CreatorUserRole.ACCOUNT | |||
| else: | |||
| raise ValueError(f"Invalid user type: {type(user)}") | |||
| self._workflow_cycle_manager = WorkflowCycleManager( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_system_variables={ | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._task_state = WorkflowTaskState() | |||
| self._workflow_run_id = "" | |||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| """ | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._base_task_pipeline._stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: | |||
| """ | |||
| To blocking response. | |||
| :return: | |||
| """ | |||
| for stream_response in generator: | |||
| if isinstance(stream_response, ErrorStreamResponse): | |||
| raise stream_response.err | |||
| elif isinstance(stream_response, WorkflowFinishStreamResponse): | |||
| response = WorkflowAppBlockingResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run_id=stream_response.data.id, | |||
| data=WorkflowAppBlockingResponse.Data( | |||
| id=stream_response.data.id, | |||
| workflow_id=stream_response.data.workflow_id, | |||
| status=stream_response.data.status, | |||
| outputs=stream_response.data.outputs, | |||
| error=stream_response.data.error, | |||
| elapsed_time=stream_response.data.elapsed_time, | |||
| total_tokens=stream_response.data.total_tokens, | |||
| total_steps=stream_response.data.total_steps, | |||
| created_at=int(stream_response.data.created_at), | |||
| finished_at=int(stream_response.data.finished_at), | |||
| ), | |||
| ) | |||
| return response | |||
| else: | |||
| continue | |||
| raise ValueError("queue listening stopped unexpectedly.") | |||
| def _to_stream_response( | |||
| self, generator: Generator[StreamResponse, None, None] | |||
| ) -> Generator[WorkflowAppStreamResponse, None, None]: | |||
| """ | |||
| To stream response. | |||
| :return: | |||
| """ | |||
| workflow_run_id = None | |||
| for stream_response in generator: | |||
| if isinstance(stream_response, WorkflowStartStreamResponse): | |||
| workflow_run_id = stream_response.workflow_run_id | |||
| yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) | |||
| def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): | |||
| if not publisher: | |||
| return None | |||
| audio_msg = publisher.check_and_get_audio() | |||
| if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": | |||
| return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) | |||
| return None | |||
| def _wrapper_process_stream_response( | |||
| self, trace_manager: Optional[TraceQueueManager] = None | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| tts_publisher = None | |||
| task_id = self._application_generate_entity.task_id | |||
| tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| features_dict = self._workflow_features_dict | |||
| if ( | |||
| features_dict.get("text_to_speech") | |||
| and features_dict["text_to_speech"].get("enabled") | |||
| and features_dict["text_to_speech"].get("autoPlay") == "enabled" | |||
| ): | |||
| tts_publisher = AppGeneratorTTSPublisher( | |||
| tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") | |||
| ) | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| break | |||
| yield response | |||
| start_listener_time = time.time() | |||
| while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: | |||
| try: | |||
| if not tts_publisher: | |||
| break | |||
| audio_trunk = tts_publisher.check_and_get_audio() | |||
| if audio_trunk is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) | |||
| continue | |||
| if audio_trunk.status == "finish": | |||
| break | |||
| else: | |||
| yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) | |||
| except Exception: | |||
| logger.exception(f"Fails to get audio trunk, task_id: {task_id}") | |||
| break | |||
| if tts_publisher: | |||
| yield MessageAudioEndStreamResponse(audio="", task_id=task_id) | |||
| def _process_stream_response( | |||
| self, | |||
| tts_publisher: Optional[AppGeneratorTTSPublisher] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """ | |||
| Process stream response. | |||
| :return: | |||
| """ | |||
| graph_runtime_state = None | |||
| for queue_message in self._base_task_pipeline._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._base_task_pipeline._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._base_task_pipeline._handle_error(event=event) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| ) | |||
| self._workflow_run_id = workflow_execution.id | |||
| start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| yield start_resp | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeRetryEvent, | |||
| ): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| session.commit() | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeStartedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( | |||
| workflow_execution_id=self._workflow_run_id, event=event | |||
| ) | |||
| node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_start_response: | |||
| yield node_start_response | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( | |||
| event=event | |||
| ) | |||
| node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_success_response: | |||
| yield node_success_response | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeFailedEvent | |||
| | QueueNodeInIterationFailedEvent | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| ): | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( | |||
| event=event, | |||
| ) | |||
| node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_failed_response: | |||
| yield node_failed_response | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_start_resp | |||
| elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_finish_resp | |||
| elif isinstance(event, QueueIterationStartEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_start_resp | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_next_resp | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_finish_resp | |||
| elif isinstance(event, QueueLoopStartEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_start_resp | |||
| elif isinstance(event, QueueLoopNextEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_next_resp | |||
| elif isinstance(event, QueueLoopCompletedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_finish_resp | |||
| elif isinstance(event, QueueWorkflowSucceededEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( | |||
| workflow_run_id=self._workflow_run_id, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| error_message=event.error | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else event.get_stop_reason(), | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| if delta_text is None: | |||
| continue | |||
| # only publish tts message at text chunk streaming | |||
| if tts_publisher: | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._text_chunk_to_stream_response( | |||
| delta_text, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueAgentLogEvent): | |||
| yield self._workflow_cycle_manager.handle_agent_log( | |||
| task_id=self._application_generate_entity.task_id, event=event | |||
| ) | |||
| else: | |||
| continue | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) | |||
| assert workflow_run is not None | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | |||
| elif invoke_from == InvokeFrom.EXPLORE: | |||
| created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP | |||
| elif invoke_from == InvokeFrom.WEB_APP: | |||
| created_from = WorkflowAppLogCreatedFrom.WEB_APP | |||
| else: | |||
| # not save log for debugging | |||
| return | |||
| workflow_app_log = WorkflowAppLog() | |||
| workflow_app_log.tenant_id = workflow_run.tenant_id | |||
| workflow_app_log.app_id = workflow_run.app_id | |||
| workflow_app_log.workflow_id = workflow_run.workflow_id | |||
| workflow_app_log.workflow_run_id = workflow_run.id | |||
| workflow_app_log.created_from = created_from.value | |||
| workflow_app_log.created_by_role = self._created_by_role | |||
| workflow_app_log.created_by = self._user_id | |||
| session.add(workflow_app_log) | |||
| session.commit() | |||
| def _text_chunk_to_stream_response( | |||
| self, text: str, from_variable_selector: Optional[list[str]] = None | |||
| ) -> TextChunkStreamResponse: | |||
| """ | |||
| Handle completed event. | |||
| :param text: text | |||
| :return: | |||
| """ | |||
| response = TextChunkStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), | |||
| ) | |||
| return response | |||
| @@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse): | |||
| id: str | |||
| workflow_id: str | |||
| sequence_number: int | |||
| inputs: dict | |||
| inputs: Mapping[str, Any] | |||
| created_at: int | |||
| event: StreamEvent = StreamEvent.WORKFLOW_STARTED | |||
| @@ -212,7 +212,7 @@ class WorkflowFinishStreamResponse(StreamResponse): | |||
| workflow_id: str | |||
| sequence_number: int | |||
| status: str | |||
| outputs: Optional[dict] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| total_tokens: int | |||
| @@ -788,7 +788,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): | |||
| id: str | |||
| workflow_id: str | |||
| status: str | |||
| outputs: Optional[dict] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| total_tokens: int | |||
| @@ -30,6 +30,7 @@ from core.ops.entities.trace_entity import ( | |||
| WorkflowTraceInfo, | |||
| ) | |||
| from core.ops.utils import get_message_data | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig | |||
| @@ -373,7 +374,7 @@ class TraceTask: | |||
| self, | |||
| trace_type: Any, | |||
| message_id: Optional[str] = None, | |||
| workflow_run: Optional[WorkflowRun] = None, | |||
| workflow_execution: Optional[WorkflowExecution] = None, | |||
| conversation_id: Optional[str] = None, | |||
| user_id: Optional[str] = None, | |||
| timer: Optional[Any] = None, | |||
| @@ -381,7 +382,7 @@ class TraceTask: | |||
| ): | |||
| self.trace_type = trace_type | |||
| self.message_id = message_id | |||
| self.workflow_run_id = workflow_run.id if workflow_run else None | |||
| self.workflow_run_id = workflow_execution.id if workflow_execution else None | |||
| self.conversation_id = conversation_id | |||
| self.user_id = user_id | |||
| self.timer = timer | |||
| @@ -0,0 +1,242 @@ | |||
| """ | |||
| SQLAlchemy implementation of the WorkflowExecutionRepository. | |||
| """ | |||
| import json | |||
| import logging | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.engine import Engine | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.workflow.entities.workflow_execution_entities import ( | |||
| WorkflowExecution, | |||
| WorkflowExecutionStatus, | |||
| WorkflowType, | |||
| ) | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| EndUser, | |||
| WorkflowRun, | |||
| ) | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): | |||
| """ | |||
| SQLAlchemy implementation of the WorkflowExecutionRepository interface. | |||
| This implementation supports multi-tenancy by filtering operations based on tenant_id. | |||
| Each method creates its own session, handles the transaction, and commits changes | |||
| to the database. This prevents long-running connections in the workflow core. | |||
| This implementation also includes an in-memory cache for workflow executions to improve | |||
| performance by reducing database queries. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| session_factory: sessionmaker | Engine, | |||
| user: Union[Account, EndUser], | |||
| app_id: Optional[str], | |||
| triggered_from: Optional[WorkflowRunTriggeredFrom], | |||
| ): | |||
| """ | |||
| Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. | |||
| Args: | |||
| session_factory: SQLAlchemy sessionmaker or engine for creating sessions | |||
| user: Account or EndUser object containing tenant_id, user ID, and role information | |||
| app_id: App ID for filtering by application (can be None) | |||
| triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) | |||
| """ | |||
| # If an engine is provided, create a sessionmaker from it | |||
| if isinstance(session_factory, Engine): | |||
| self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) | |||
| elif isinstance(session_factory, sessionmaker): | |||
| self._session_factory = session_factory | |||
| else: | |||
| raise ValueError( | |||
| f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" | |||
| ) | |||
| # Extract tenant_id from user | |||
| tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id | |||
| if not tenant_id: | |||
| raise ValueError("User must have a tenant_id or current_tenant_id") | |||
| self._tenant_id = tenant_id | |||
| # Store app context | |||
| self._app_id = app_id | |||
| # Extract user context | |||
| self._triggered_from = triggered_from | |||
| self._creator_user_id = user.id | |||
| # Determine user role based on user type | |||
| self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER | |||
| # Initialize in-memory cache for workflow executions | |||
| # Key: execution_id, Value: WorkflowRun (DB model) | |||
| self._execution_cache: dict[str, WorkflowRun] = {} | |||
| def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution: | |||
| """ | |||
| Convert a database model to a domain model. | |||
| Args: | |||
| db_model: The database model to convert | |||
| Returns: | |||
| The domain model | |||
| """ | |||
| # Parse JSON fields | |||
| inputs = db_model.inputs_dict | |||
| outputs = db_model.outputs_dict | |||
| graph = db_model.graph_dict | |||
| # Convert status to domain enum | |||
| status = WorkflowExecutionStatus(db_model.status) | |||
| return WorkflowExecution( | |||
| id=db_model.id, | |||
| workflow_id=db_model.workflow_id, | |||
| sequence_number=db_model.sequence_number, | |||
| type=WorkflowType(db_model.type), | |||
| workflow_version=db_model.version, | |||
| graph=graph, | |||
| inputs=inputs, | |||
| outputs=outputs, | |||
| status=status, | |||
| error_message=db_model.error or "", | |||
| total_tokens=db_model.total_tokens, | |||
| total_steps=db_model.total_steps, | |||
| exceptions_count=db_model.exceptions_count, | |||
| started_at=db_model.created_at, | |||
| finished_at=db_model.finished_at, | |||
| ) | |||
| def _to_db_model(self, domain_model: WorkflowExecution) -> WorkflowRun: | |||
| """ | |||
| Convert a domain model to a database model. | |||
| Args: | |||
| domain_model: The domain model to convert | |||
| Returns: | |||
| The database model | |||
| """ | |||
| # Use values from constructor if provided | |||
| if not self._triggered_from: | |||
| raise ValueError("triggered_from is required in repository constructor") | |||
| if not self._creator_user_id: | |||
| raise ValueError("created_by is required in repository constructor") | |||
| if not self._creator_user_role: | |||
| raise ValueError("created_by_role is required in repository constructor") | |||
| db_model = WorkflowRun() | |||
| db_model.id = domain_model.id | |||
| db_model.tenant_id = self._tenant_id | |||
| if self._app_id is not None: | |||
| db_model.app_id = self._app_id | |||
| db_model.workflow_id = domain_model.workflow_id | |||
| db_model.triggered_from = self._triggered_from | |||
| db_model.sequence_number = domain_model.sequence_number | |||
| db_model.type = domain_model.type | |||
| db_model.version = domain_model.workflow_version | |||
| db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None | |||
| db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None | |||
| db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None | |||
| db_model.status = domain_model.status | |||
| db_model.error = domain_model.error_message if domain_model.error_message else None | |||
| db_model.total_tokens = domain_model.total_tokens | |||
| db_model.total_steps = domain_model.total_steps | |||
| db_model.exceptions_count = domain_model.exceptions_count | |||
| db_model.created_by_role = self._creator_user_role | |||
| db_model.created_by = self._creator_user_id | |||
| db_model.created_at = domain_model.started_at | |||
| db_model.finished_at = domain_model.finished_at | |||
| # Calculate elapsed time if finished_at is available | |||
| if domain_model.finished_at: | |||
| db_model.elapsed_time = (domain_model.finished_at - domain_model.started_at).total_seconds() | |||
| else: | |||
| db_model.elapsed_time = 0 | |||
| return db_model | |||
| def save(self, execution: WorkflowExecution) -> None: | |||
| """ | |||
| Save or update a WorkflowExecution domain entity to the database. | |||
| This method serves as a domain-to-database adapter that: | |||
| 1. Converts the domain entity to its database representation | |||
| 2. Persists the database model using SQLAlchemy's merge operation | |||
| 3. Maintains proper multi-tenancy by including tenant context during conversion | |||
| 4. Updates the in-memory cache for faster subsequent lookups | |||
| The method handles both creating new records and updating existing ones through | |||
| SQLAlchemy's merge operation. | |||
| Args: | |||
| execution: The WorkflowExecution domain entity to persist | |||
| """ | |||
| # Convert domain model to database model using tenant context and other attributes | |||
| db_model = self._to_db_model(execution) | |||
| # Create a new database session | |||
| with self._session_factory() as session: | |||
| # SQLAlchemy merge intelligently handles both insert and update operations | |||
| # based on the presence of the primary key | |||
| session.merge(db_model) | |||
| session.commit() | |||
| # Update the in-memory cache for faster subsequent lookups | |||
| logger.debug(f"Updating cache for execution_id: {db_model.id}") | |||
| self._execution_cache[db_model.id] = db_model | |||
| def get(self, execution_id: str) -> Optional[WorkflowExecution]: | |||
| """ | |||
| Retrieve a WorkflowExecution by its ID. | |||
| First checks the in-memory cache, and if not found, queries the database. | |||
| If found in the database, adds it to the cache for future lookups. | |||
| Args: | |||
| execution_id: The workflow execution ID | |||
| Returns: | |||
| The WorkflowExecution instance if found, None otherwise | |||
| """ | |||
| # First check the cache | |||
| if execution_id in self._execution_cache: | |||
| logger.debug(f"Cache hit for execution_id: {execution_id}") | |||
| # Convert cached DB model to domain model | |||
| cached_db_model = self._execution_cache[execution_id] | |||
| return self._to_domain_model(cached_db_model) | |||
| # If not in cache, query the database | |||
| logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") | |||
| with self._session_factory() as session: | |||
| stmt = select(WorkflowRun).where( | |||
| WorkflowRun.id == execution_id, | |||
| WorkflowRun.tenant_id == self._tenant_id, | |||
| ) | |||
| if self._app_id: | |||
| stmt = stmt.where(WorkflowRun.app_id == self._app_id) | |||
| db_model = session.scalar(stmt) | |||
| if db_model: | |||
| # Add DB model to cache | |||
| self._execution_cache[execution_id] = db_model | |||
| # Convert to domain model and return | |||
| return self._to_domain_model(db_model) | |||
| return None | |||
| @@ -0,0 +1,91 @@ | |||
| """ | |||
| Domain entities for workflow execution. | |||
| Models are independent of the storage mechanism and don't contain | |||
| implementation details like tenant_id, app_id, etc. | |||
| """ | |||
| from collections.abc import Mapping | |||
| from datetime import UTC, datetime | |||
| from enum import StrEnum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| class WorkflowType(StrEnum): | |||
| """ | |||
| Workflow Type Enum for domain layer | |||
| """ | |||
| WORKFLOW = "workflow" | |||
| CHAT = "chat" | |||
| class WorkflowExecutionStatus(StrEnum): | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| STOPPED = "stopped" | |||
| PARTIAL_SUCCEEDED = "partial-succeeded" | |||
| class WorkflowExecution(BaseModel): | |||
| """ | |||
| Domain model for workflow execution based on WorkflowRun but without | |||
| user, tenant, and app attributes. | |||
| """ | |||
| id: str = Field(...) | |||
| workflow_id: str = Field(...) | |||
| workflow_version: str = Field(...) | |||
| sequence_number: int = Field(...) | |||
| type: WorkflowType = Field(...) | |||
| graph: Mapping[str, Any] = Field(...) | |||
| inputs: Mapping[str, Any] = Field(...) | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING | |||
| error_message: str = Field(default="") | |||
| total_tokens: int = Field(default=0) | |||
| total_steps: int = Field(default=0) | |||
| exceptions_count: int = Field(default=0) | |||
| started_at: datetime = Field(...) | |||
| finished_at: Optional[datetime] = None | |||
| @property | |||
| def elapsed_time(self) -> float: | |||
| """ | |||
| Calculate elapsed time in seconds. | |||
| If workflow is not finished, use current time. | |||
| """ | |||
| end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) | |||
| return (end_time - self.started_at).total_seconds() | |||
| @classmethod | |||
| def new( | |||
| cls, | |||
| *, | |||
| id: str, | |||
| workflow_id: str, | |||
| sequence_number: int, | |||
| type: WorkflowType, | |||
| workflow_version: str, | |||
| graph: Mapping[str, Any], | |||
| inputs: Mapping[str, Any], | |||
| started_at: datetime, | |||
| ) -> "WorkflowExecution": | |||
| return WorkflowExecution( | |||
| id=id, | |||
| workflow_id=workflow_id, | |||
| sequence_number=sequence_number, | |||
| type=type, | |||
| workflow_version=workflow_version, | |||
| graph=graph, | |||
| inputs=inputs, | |||
| status=WorkflowExecutionStatus.RUNNING, | |||
| started_at=started_at, | |||
| ) | |||
| @@ -0,0 +1,42 @@ | |||
| from typing import Optional, Protocol | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| class WorkflowExecutionRepository(Protocol): | |||
| """ | |||
| Repository interface for WorkflowExecution. | |||
| This interface defines the contract for accessing and manipulating | |||
| WorkflowExecution data, regardless of the underlying storage mechanism. | |||
| Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), | |||
| and other implementation details should be handled at the implementation level, not in | |||
| the core interface. This keeps the core domain model clean and independent of specific | |||
| application domains or deployment scenarios. | |||
| """ | |||
| def save(self, execution: WorkflowExecution) -> None: | |||
| """ | |||
| Save or update a WorkflowExecution instance. | |||
| This method handles both creating new records and updating existing ones. | |||
| The implementation should determine whether to create or update based on | |||
| the execution's ID or other identifying fields. | |||
| Args: | |||
| execution: The WorkflowExecution instance to save or update | |||
| """ | |||
| ... | |||
| def get(self, execution_id: str) -> Optional[WorkflowExecution]: | |||
| """ | |||
| Retrieve a WorkflowExecution by its ID. | |||
| Args: | |||
| execution_id: The workflow execution ID | |||
| Returns: | |||
| The WorkflowExecution instance if found, None otherwise | |||
| """ | |||
| ... | |||
| @@ -3,6 +3,7 @@ import time | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| @@ -53,7 +54,9 @@ from core.app.entities.task_entities import ( | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from extensions.ext_database import db | |||
| @@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| @@ -111,6 +115,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| @@ -258,17 +263,15 @@ class WorkflowAppGenerateTaskPipeline: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| created_by_role=self._created_by_role, | |||
| ) | |||
| self._workflow_run_id = workflow_run.id | |||
| start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| self._workflow_run_id = workflow_execution.id | |||
| start_resp = self._workflow_cycle_manager.workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| yield start_resp | |||
| elif isinstance( | |||
| @@ -278,13 +281,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( | |||
| workflow_run=workflow_run, event=event | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( | |||
| response = self._workflow_cycle_manager.workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -297,27 +298,22 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| session.commit() | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( | |||
| workflow_execution_id=self._workflow_run_id, event=event | |||
| ) | |||
| node_start_response = self._workflow_cycle_manager.workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_start_response: | |||
| yield node_start_response | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( | |||
| event=event | |||
| ) | |||
| node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| node_success_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -332,10 +328,10 @@ class WorkflowAppGenerateTaskPipeline: | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| ): | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( | |||
| event=event, | |||
| ) | |||
| node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| node_failed_response = self._workflow_cycle_manager.workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| @@ -348,18 +344,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_start_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| parallel_start_resp = self._workflow_cycle_manager.workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_start_resp | |||
| @@ -367,18 +356,13 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager.workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_finish_resp | |||
| @@ -386,16 +370,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager.workflow_iteration_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_start_resp | |||
| @@ -403,16 +382,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager.workflow_iteration_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_next_resp | |||
| @@ -420,16 +394,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager.workflow_iteration_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield iter_finish_resp | |||
| @@ -437,16 +406,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_start_resp = self._workflow_cycle_manager.workflow_loop_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_start_resp | |||
| @@ -454,16 +418,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_next_resp = self._workflow_cycle_manager.workflow_loop_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_next_resp | |||
| @@ -471,16 +430,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| loop_finish_resp = self._workflow_cycle_manager.workflow_loop_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield loop_finish_resp | |||
| @@ -491,10 +445,8 @@ class WorkflowAppGenerateTaskPipeline: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| @@ -503,12 +455,12 @@ class WorkflowAppGenerateTaskPipeline: | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| @@ -520,10 +472,8 @@ class WorkflowAppGenerateTaskPipeline: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| @@ -533,10 +483,12 @@ class WorkflowAppGenerateTaskPipeline: | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| @@ -548,26 +500,28 @@ class WorkflowAppGenerateTaskPipeline: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| error_message=event.error | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else event.get_stop_reason(), | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_cycle_manager.workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution=workflow_execution, | |||
| ) | |||
| session.commit() | |||
| @@ -586,7 +540,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| delta_text, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueAgentLogEvent): | |||
| yield self._workflow_cycle_manager._handle_agent_log( | |||
| yield self._workflow_cycle_manager.handle_agent_log( | |||
| task_id=self._application_generate_entity.task_id, event=event | |||
| ) | |||
| else: | |||
| @@ -595,11 +549,9 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: | |||
| """ | |||
| Save workflow app log. | |||
| :return: | |||
| """ | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) | |||
| assert workflow_run is not None | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import time | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| @@ -8,7 +7,7 @@ from uuid import uuid4 | |||
| from sqlalchemy import func, select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAgentLogEvent, | |||
| QueueIterationCompletedEvent, | |||
| @@ -54,9 +53,11 @@ from core.workflow.entities.node_execution_entities import ( | |||
| NodeExecution, | |||
| NodeExecutionStatus, | |||
| ) | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from models import ( | |||
| @@ -67,7 +68,6 @@ from models import ( | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| WorkflowRunTriggeredFrom, | |||
| ) | |||
| @@ -77,21 +77,20 @@ class WorkflowCycleManager: | |||
| *, | |||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||
| workflow_system_variables: dict[SystemVariableKey, Any], | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._workflow_run: WorkflowRun | None = None | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_system_variables = workflow_system_variables | |||
| self._workflow_execution_repository = workflow_execution_repository | |||
| self._workflow_node_execution_repository = workflow_node_execution_repository | |||
| def _handle_workflow_run_start( | |||
| def handle_workflow_run_start( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| created_by_role: CreatorUserRole, | |||
| ) -> WorkflowRun: | |||
| ) -> WorkflowExecution: | |||
| workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | |||
| workflow = session.scalar(workflow_stmt) | |||
| if not workflow: | |||
| @@ -110,157 +109,116 @@ class WorkflowCycleManager: | |||
| continue | |||
| inputs[f"sys.{key.value}"] = value | |||
| triggered_from = ( | |||
| WorkflowRunTriggeredFrom.DEBUGGING | |||
| if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER | |||
| else WorkflowRunTriggeredFrom.APP_RUN | |||
| ) | |||
| # handle special values | |||
| inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) | |||
| # init workflow run | |||
| # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this | |||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) | |||
| workflow_run = WorkflowRun() | |||
| workflow_run.id = workflow_run_id | |||
| workflow_run.tenant_id = workflow.tenant_id | |||
| workflow_run.app_id = workflow.app_id | |||
| workflow_run.sequence_number = new_sequence_number | |||
| workflow_run.workflow_id = workflow.id | |||
| workflow_run.type = workflow.type | |||
| workflow_run.triggered_from = triggered_from.value | |||
| workflow_run.version = workflow.version | |||
| workflow_run.graph = workflow.graph | |||
| workflow_run.inputs = json.dumps(inputs) | |||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||
| workflow_run.created_by_role = created_by_role | |||
| workflow_run.created_by = user_id | |||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| session.add(workflow_run) | |||
| return workflow_run | |||
| def _handle_workflow_run_success( | |||
| execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) | |||
| execution = WorkflowExecution.new( | |||
| id=execution_id, | |||
| workflow_id=workflow.id, | |||
| sequence_number=new_sequence_number, | |||
| type=WorkflowType(workflow.type), | |||
| workflow_version=workflow.version, | |||
| graph=workflow.graph_dict, | |||
| inputs=inputs, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| self._workflow_execution_repository.save(execution) | |||
| return execution | |||
| def handle_workflow_run_success( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run_id: str, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| outputs: Mapping[str, Any] | None = None, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run success | |||
| :param workflow_run_id: workflow run id | |||
| :param start_at: start time | |||
| :param total_tokens: total tokens | |||
| :param total_steps: total steps | |||
| :param outputs: outputs | |||
| :param conversation_id: conversation id | |||
| :return: | |||
| """ | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) | |||
| ) -> WorkflowExecution: | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) | |||
| outputs = WorkflowEntry.handle_special_values(outputs) | |||
| workflow_run.status = WorkflowRunStatus.SUCCEEDED | |||
| workflow_run.outputs = json.dumps(outputs or {}) | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED | |||
| workflow_execution.outputs = outputs or {} | |||
| workflow_execution.total_tokens = total_tokens | |||
| workflow_execution.total_steps = total_steps | |||
| workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.WORKFLOW_TRACE, | |||
| workflow_run=workflow_run, | |||
| workflow_execution=workflow_execution, | |||
| conversation_id=conversation_id, | |||
| user_id=trace_manager.user_id, | |||
| ) | |||
| ) | |||
| return workflow_run | |||
| return workflow_execution | |||
| def _handle_workflow_run_partial_success( | |||
| def handle_workflow_run_partial_success( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run_id: str, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| outputs: Mapping[str, Any] | None = None, | |||
| exceptions_count: int = 0, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) | |||
| ) -> WorkflowExecution: | |||
| execution = self._get_workflow_execution_or_raise_error(workflow_run_id) | |||
| outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) | |||
| workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value | |||
| workflow_run.outputs = json.dumps(outputs or {}) | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED | |||
| execution.outputs = outputs or {} | |||
| execution.total_tokens = total_tokens | |||
| execution.total_steps = total_steps | |||
| execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| execution.exceptions_count = exceptions_count | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.WORKFLOW_TRACE, | |||
| workflow_run=workflow_run, | |||
| workflow_execution=execution, | |||
| conversation_id=conversation_id, | |||
| user_id=trace_manager.user_id, | |||
| ) | |||
| ) | |||
| return workflow_run | |||
| return execution | |||
| def _handle_workflow_run_failed( | |||
| def handle_workflow_run_failed( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run_id: str, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| status: WorkflowRunStatus, | |||
| error: str, | |||
| error_message: str, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| exceptions_count: int = 0, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run failed | |||
| :param workflow_run_id: workflow run id | |||
| :param start_at: start time | |||
| :param total_tokens: total tokens | |||
| :param total_steps: total steps | |||
| :param status: status | |||
| :param error: error message | |||
| :return: | |||
| """ | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) | |||
| ) -> WorkflowExecution: | |||
| execution = self._get_workflow_execution_or_raise_error(workflow_run_id) | |||
| workflow_run.status = status.value | |||
| workflow_run.error = error | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| execution.status = WorkflowExecutionStatus(status.value) | |||
| execution.error_message = error_message | |||
| execution.total_tokens = total_tokens | |||
| execution.total_steps = total_steps | |||
| execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| execution.exceptions_count = exceptions_count | |||
| # Use the instance repository to find running executions for a workflow run | |||
| running_domain_executions = self._workflow_node_execution_repository.get_running_executions( | |||
| workflow_run_id=workflow_run.id | |||
| workflow_run_id=execution.id | |||
| ) | |||
| # Update the domain models | |||
| @@ -269,7 +227,7 @@ class WorkflowCycleManager: | |||
| if domain_execution.node_execution_id: | |||
| # Update the domain model | |||
| domain_execution.status = NodeExecutionStatus.FAILED | |||
| domain_execution.error = error | |||
| domain_execution.error = error_message | |||
| domain_execution.finished_at = now | |||
| domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() | |||
| @@ -280,15 +238,22 @@ class WorkflowCycleManager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.WORKFLOW_TRACE, | |||
| workflow_run=workflow_run, | |||
| workflow_execution=execution, | |||
| conversation_id=conversation_id, | |||
| user_id=trace_manager.user_id, | |||
| ) | |||
| ) | |||
| return workflow_run | |||
| return execution | |||
| def handle_node_execution_start( | |||
| self, | |||
| *, | |||
| workflow_execution_id: str, | |||
| event: QueueNodeStartedEvent, | |||
| ) -> NodeExecution: | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) | |||
| def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: | |||
| # Create a domain model | |||
| created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| metadata = { | |||
| @@ -299,8 +264,8 @@ class WorkflowCycleManager: | |||
| domain_execution = NodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_run.workflow_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| index=event.node_run_index, | |||
| node_execution_id=event.node_execution_id, | |||
| @@ -317,7 +282,7 @@ class WorkflowCycleManager: | |||
| return domain_execution | |||
| def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: | |||
| def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: | |||
| # Get the domain model from repository | |||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||
| if not domain_execution: | |||
| @@ -350,7 +315,7 @@ class WorkflowCycleManager: | |||
| return domain_execution | |||
| def _handle_workflow_node_execution_failed( | |||
| def handle_workflow_node_execution_failed( | |||
| self, | |||
| *, | |||
| event: QueueNodeFailedEvent | |||
| @@ -400,15 +365,10 @@ class WorkflowCycleManager: | |||
| return domain_execution | |||
| def _handle_workflow_node_execution_retried( | |||
| self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent | |||
| def handle_workflow_node_execution_retried( | |||
| self, *, workflow_execution_id: str, event: QueueNodeRetryEvent | |||
| ) -> NodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param workflow_run: workflow run | |||
| :param event: queue node failed event | |||
| :return: | |||
| """ | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) | |||
| created_at = event.start_at | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - created_at).total_seconds() | |||
| @@ -433,8 +393,8 @@ class WorkflowCycleManager: | |||
| # Create a domain model | |||
| domain_execution = NodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_run.workflow_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| workflow_run_id=workflow_execution.id, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| node_execution_id=event.node_execution_id, | |||
| node_id=event.node_id, | |||
| @@ -456,34 +416,34 @@ class WorkflowCycleManager: | |||
| return domain_execution | |||
| def _workflow_start_to_stream_response( | |||
| def workflow_start_to_stream_response( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| workflow_execution: WorkflowExecution, | |||
| ) -> WorkflowStartStreamResponse: | |||
| _ = session | |||
| return WorkflowStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution.id, | |||
| data=WorkflowStartStreamResponse.Data( | |||
| id=workflow_run.id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| sequence_number=workflow_run.sequence_number, | |||
| inputs=dict(workflow_run.inputs_dict or {}), | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| id=workflow_execution.id, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| sequence_number=workflow_execution.sequence_number, | |||
| inputs=workflow_execution.inputs, | |||
| created_at=int(workflow_execution.started_at.timestamp()), | |||
| ), | |||
| ) | |||
| def _workflow_finish_to_stream_response( | |||
| def workflow_finish_to_stream_response( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| workflow_execution: WorkflowExecution, | |||
| ) -> WorkflowFinishStreamResponse: | |||
| created_by = None | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id)) | |||
| assert workflow_run is not None | |||
| if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: | |||
| stmt = select(Account).where(Account.id == workflow_run.created_by) | |||
| account = session.scalar(stmt) | |||
| @@ -504,28 +464,35 @@ class WorkflowCycleManager: | |||
| else: | |||
| raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") | |||
| # Handle the case where finished_at is None by using current time as default | |||
| finished_at_timestamp = ( | |||
| int(workflow_execution.finished_at.timestamp()) | |||
| if workflow_execution.finished_at | |||
| else int(datetime.now(UTC).timestamp()) | |||
| ) | |||
| return WorkflowFinishStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution.id, | |||
| data=WorkflowFinishStreamResponse.Data( | |||
| id=workflow_run.id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| sequence_number=workflow_run.sequence_number, | |||
| status=workflow_run.status, | |||
| outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, | |||
| error=workflow_run.error, | |||
| elapsed_time=workflow_run.elapsed_time, | |||
| total_tokens=workflow_run.total_tokens, | |||
| total_steps=workflow_run.total_steps, | |||
| id=workflow_execution.id, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| sequence_number=workflow_execution.sequence_number, | |||
| status=workflow_execution.status, | |||
| outputs=workflow_execution.outputs, | |||
| error=workflow_execution.error_message, | |||
| elapsed_time=workflow_execution.elapsed_time, | |||
| total_tokens=workflow_execution.total_tokens, | |||
| total_steps=workflow_execution.total_steps, | |||
| created_by=created_by, | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| finished_at=int(workflow_run.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), | |||
| exceptions_count=workflow_run.exceptions_count, | |||
| created_at=int(workflow_execution.started_at.timestamp()), | |||
| finished_at=finished_at_timestamp, | |||
| files=self.fetch_files_from_node_outputs(workflow_execution.outputs), | |||
| exceptions_count=workflow_execution.exceptions_count, | |||
| ), | |||
| ) | |||
| def _workflow_node_start_to_stream_response( | |||
| def workflow_node_start_to_stream_response( | |||
| self, | |||
| *, | |||
| event: QueueNodeStartedEvent, | |||
| @@ -571,7 +538,7 @@ class WorkflowCycleManager: | |||
| return response | |||
| def _workflow_node_finish_to_stream_response( | |||
| def workflow_node_finish_to_stream_response( | |||
| self, | |||
| *, | |||
| event: QueueNodeSucceededEvent | |||
| @@ -608,7 +575,7 @@ class WorkflowCycleManager: | |||
| execution_metadata=workflow_node_execution.metadata, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| @@ -618,7 +585,7 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_node_retry_to_stream_response( | |||
| def workflow_node_retry_to_stream_response( | |||
| self, | |||
| *, | |||
| event: QueueNodeRetryEvent, | |||
| @@ -651,7 +618,7 @@ class WorkflowCycleManager: | |||
| execution_metadata=workflow_node_execution.metadata, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| @@ -662,13 +629,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_parallel_branch_start_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent | |||
| def workflow_parallel_branch_start_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueParallelBranchRunStartedEvent, | |||
| ) -> ParallelBranchStartStreamResponse: | |||
| _ = session | |||
| return ParallelBranchStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=ParallelBranchStartStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| @@ -680,18 +650,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_parallel_branch_finished_to_stream_response( | |||
| def workflow_parallel_branch_finished_to_stream_response( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| workflow_execution_id: str, | |||
| event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, | |||
| ) -> ParallelBranchFinishedStreamResponse: | |||
| _ = session | |||
| return ParallelBranchFinishedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=ParallelBranchFinishedStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| @@ -705,13 +673,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_iteration_start_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent | |||
| def workflow_iteration_start_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueIterationStartEvent, | |||
| ) -> IterationNodeStartStreamResponse: | |||
| _ = session | |||
| return IterationNodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=IterationNodeStartStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -726,13 +697,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_iteration_next_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent | |||
| def workflow_iteration_next_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueIterationNextEvent, | |||
| ) -> IterationNodeNextStreamResponse: | |||
| _ = session | |||
| return IterationNodeNextStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=IterationNodeNextStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -749,13 +723,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_iteration_completed_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent | |||
| def workflow_iteration_completed_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueIterationCompletedEvent, | |||
| ) -> IterationNodeCompletedStreamResponse: | |||
| _ = session | |||
| return IterationNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=IterationNodeCompletedStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -779,13 +756,12 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_loop_start_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent | |||
| def workflow_loop_start_to_stream_response( | |||
| self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent | |||
| ) -> LoopNodeStartStreamResponse: | |||
| _ = session | |||
| return LoopNodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=LoopNodeStartStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -800,13 +776,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_loop_next_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent | |||
| def workflow_loop_next_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueLoopNextEvent, | |||
| ) -> LoopNodeNextStreamResponse: | |||
| _ = session | |||
| return LoopNodeNextStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=LoopNodeNextStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -823,13 +802,16 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _workflow_loop_completed_to_stream_response( | |||
| self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent | |||
| def workflow_loop_completed_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueLoopCompletedEvent, | |||
| ) -> LoopNodeCompletedStreamResponse: | |||
| _ = session | |||
| return LoopNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=LoopNodeCompletedStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| @@ -853,7 +835,7 @@ class WorkflowCycleManager: | |||
| ), | |||
| ) | |||
| def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: | |||
| def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: | |||
| """ | |||
| Fetch files from node outputs | |||
| :param outputs_dict: node outputs dict | |||
| @@ -910,20 +892,13 @@ class WorkflowCycleManager: | |||
| return None | |||
| def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: | |||
| if self._workflow_run and self._workflow_run.id == workflow_run_id: | |||
| cached_workflow_run = self._workflow_run | |||
| cached_workflow_run = session.merge(cached_workflow_run) | |||
| return cached_workflow_run | |||
| stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | |||
| workflow_run = session.scalar(stmt) | |||
| if not workflow_run: | |||
| raise WorkflowRunNotFoundError(workflow_run_id) | |||
| self._workflow_run = workflow_run | |||
| return workflow_run | |||
| def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: | |||
| execution = self._workflow_execution_repository.get(id) | |||
| if not execution: | |||
| raise WorkflowRunNotFoundError(id) | |||
| return execution | |||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||
| def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||
| """ | |||
| Handle agent log | |||
| :param task_id: task id | |||
| @@ -425,14 +425,14 @@ class WorkflowRun(Base): | |||
| status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded | |||
| outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") | |||
| error: Mapped[Optional[str]] = mapped_column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) | |||
| elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) | |||
| total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) | |||
| total_steps = db.Column(db.Integer, server_default=db.text("0")) | |||
| total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) | |||
| created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| finished_at = db.Column(db.DateTime) | |||
| exceptions_count = db.Column(db.Integer, server_default=db.text("0")) | |||
| created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||
| exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) | |||
| @property | |||
| def created_by_account(self): | |||
| @@ -447,7 +447,7 @@ class WorkflowRun(Base): | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||
| @property | |||
| def graph_dict(self): | |||
| def graph_dict(self) -> Mapping[str, Any]: | |||
| return json.loads(self.graph) if self.graph else {} | |||
| @property | |||
| @@ -752,12 +752,12 @@ class WorkflowAppLog(Base): | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID) | |||
| app_id: Mapped[str] = mapped_column(StringUUID) | |||
| workflow_id = db.Column(StringUUID, nullable=False) | |||
| workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| workflow_run_id: Mapped[str] = mapped_column(StringUUID) | |||
| created_from = db.Column(db.String(255), nullable=False) | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def workflow_run(self): | |||
| @@ -782,9 +782,11 @@ class ConversationVariable(Base): | |||
| id: Mapped[str] = mapped_column(StringUUID, primary_key=True) | |||
| conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) | |||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) | |||
| data = mapped_column(db.Text, nullable=False) | |||
| created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True) | |||
| updated_at = mapped_column( | |||
| data: Mapped[str] = mapped_column(db.Text, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column( | |||
| db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True | |||
| ) | |||
| updated_at: Mapped[datetime] = mapped_column( | |||
| db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() | |||
| ) | |||
| @@ -832,14 +834,14 @@ class WorkflowDraftVariable(Base): | |||
| # id is the unique identifier of a draft variable. | |||
| id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| created_at = mapped_column( | |||
| created_at: Mapped[datetime] = mapped_column( | |||
| db.DateTime, | |||
| nullable=False, | |||
| default=_naive_utc_datetime, | |||
| server_default=func.current_timestamp(), | |||
| ) | |||
| updated_at = mapped_column( | |||
| updated_at: Mapped[datetime] = mapped_column( | |||
| db.DateTime, | |||
| nullable=False, | |||
| default=_naive_utc_datetime, | |||
| @@ -1,45 +1,73 @@ | |||
| import json | |||
| import time | |||
| from datetime import UTC, datetime | |||
| from unittest.mock import MagicMock, patch | |||
| from unittest.mock import MagicMock | |||
| import pytest | |||
| from sqlalchemy.orm import Session | |||
| from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus | |||
| from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from models.enums import CreatorUserRole | |||
| from models.model import AppMode | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| @pytest.fixture | |||
| def mock_app_generate_entity(): | |||
| entity = MagicMock(spec=AdvancedChatAppGenerateEntity) | |||
| entity.inputs = {"query": "test query"} | |||
| entity.invoke_from = InvokeFrom.WEB_APP | |||
| # Create app_config as a separate mock | |||
| app_config = MagicMock() | |||
| app_config.tenant_id = "test-tenant-id" | |||
| app_config.app_id = "test-app-id" | |||
| entity.app_config = app_config | |||
| def real_app_generate_entity(): | |||
| additional_features = AppAdditionalFeatures( | |||
| file_upload=None, | |||
| opening_statement=None, | |||
| suggested_questions=[], | |||
| suggested_questions_after_answer=False, | |||
| show_retrieve_source=False, | |||
| more_like_this=False, | |||
| speech_to_text=False, | |||
| text_to_speech=None, | |||
| trace_config=None, | |||
| ) | |||
| app_config = WorkflowUIBasedAppConfig( | |||
| tenant_id="test-tenant-id", | |||
| app_id="test-app-id", | |||
| app_mode=AppMode.WORKFLOW, | |||
| additional_features=additional_features, | |||
| workflow_id="test-workflow-id", | |||
| ) | |||
| entity = AdvancedChatAppGenerateEntity( | |||
| task_id="test-task-id", | |||
| app_config=app_config, | |||
| inputs={"query": "test query"}, | |||
| files=[], | |||
| user_id="test-user-id", | |||
| stream=False, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| query="test query", | |||
| conversation_id="test-conversation-id", | |||
| ) | |||
| return entity | |||
| @pytest.fixture | |||
| def mock_workflow_system_variables(): | |||
| def real_workflow_system_variables(): | |||
| return { | |||
| SystemVariableKey.QUERY: "test query", | |||
| SystemVariableKey.CONVERSATION_ID: "test-conversation-id", | |||
| @@ -59,10 +87,23 @@ def mock_node_execution_repository(): | |||
| @pytest.fixture | |||
| def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository): | |||
| def mock_workflow_execution_repository(): | |||
| repo = MagicMock(spec=WorkflowExecutionRepository) | |||
| repo.get.return_value = None | |||
| return repo | |||
| @pytest.fixture | |||
| def workflow_cycle_manager( | |||
| real_app_generate_entity, | |||
| real_workflow_system_variables, | |||
| mock_workflow_execution_repository, | |||
| mock_node_execution_repository, | |||
| ): | |||
| return WorkflowCycleManager( | |||
| application_generate_entity=mock_app_generate_entity, | |||
| workflow_system_variables=mock_workflow_system_variables, | |||
| application_generate_entity=real_app_generate_entity, | |||
| workflow_system_variables=real_workflow_system_variables, | |||
| workflow_execution_repository=mock_workflow_execution_repository, | |||
| workflow_node_execution_repository=mock_node_execution_repository, | |||
| ) | |||
| @@ -74,121 +115,173 @@ def mock_session(): | |||
| @pytest.fixture | |||
| def mock_workflow(): | |||
| workflow = MagicMock(spec=Workflow) | |||
| def real_workflow(): | |||
| workflow = Workflow() | |||
| workflow.id = "test-workflow-id" | |||
| workflow.tenant_id = "test-tenant-id" | |||
| workflow.app_id = "test-app-id" | |||
| workflow.type = "chat" | |||
| workflow.version = "1.0" | |||
| workflow.graph = json.dumps({"nodes": [], "edges": []}) | |||
| graph_data = {"nodes": [], "edges": []} | |||
| workflow.graph = json.dumps(graph_data) | |||
| workflow.features = json.dumps({"file_upload": {"enabled": False}}) | |||
| workflow.created_by = "test-user-id" | |||
| workflow.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow._environment_variables = "{}" | |||
| workflow._conversation_variables = "{}" | |||
| return workflow | |||
| @pytest.fixture | |||
| def mock_workflow_run(): | |||
| workflow_run = MagicMock(spec=WorkflowRun) | |||
| def real_workflow_run(): | |||
| workflow_run = WorkflowRun() | |||
| workflow_run.id = "test-workflow-run-id" | |||
| workflow_run.tenant_id = "test-tenant-id" | |||
| workflow_run.app_id = "test-app-id" | |||
| workflow_run.workflow_id = "test-workflow-id" | |||
| workflow_run.sequence_number = 1 | |||
| workflow_run.type = "chat" | |||
| workflow_run.triggered_from = "app-run" | |||
| workflow_run.version = "1.0" | |||
| workflow_run.graph = json.dumps({"nodes": [], "edges": []}) | |||
| workflow_run.inputs = json.dumps({"query": "test query"}) | |||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||
| workflow_run.outputs = json.dumps({"answer": "test answer"}) | |||
| workflow_run.created_by_role = CreatorUserRole.ACCOUNT | |||
| workflow_run.created_by = "test-user-id" | |||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.inputs_dict = {"query": "test query"} | |||
| workflow_run.outputs_dict = {"answer": "test answer"} | |||
| return workflow_run | |||
| def test_init( | |||
| workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository | |||
| workflow_cycle_manager, | |||
| real_app_generate_entity, | |||
| real_workflow_system_variables, | |||
| mock_workflow_execution_repository, | |||
| mock_node_execution_repository, | |||
| ): | |||
| """Test initialization of WorkflowCycleManager""" | |||
| assert workflow_cycle_manager._workflow_run is None | |||
| assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity | |||
| assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables | |||
| assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity | |||
| assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables | |||
| assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository | |||
| assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository | |||
| def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow): | |||
| """Test _handle_workflow_run_start method""" | |||
| def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow): | |||
| """Test handle_workflow_run_start method""" | |||
| # Mock session.scalar to return the workflow and max sequence | |||
| mock_session.scalar.side_effect = [mock_workflow, 5] | |||
| mock_session.scalar.side_effect = [real_workflow, 5] | |||
| # Call the method | |||
| workflow_run = workflow_cycle_manager._handle_workflow_run_start( | |||
| workflow_execution = workflow_cycle_manager.handle_workflow_run_start( | |||
| session=mock_session, | |||
| workflow_id="test-workflow-id", | |||
| user_id="test-user-id", | |||
| created_by_role=CreatorUserRole.ACCOUNT, | |||
| ) | |||
| # Verify the result | |||
| assert workflow_run.tenant_id == mock_workflow.tenant_id | |||
| assert workflow_run.app_id == mock_workflow.app_id | |||
| assert workflow_run.workflow_id == mock_workflow.id | |||
| assert workflow_run.sequence_number == 6 # max_sequence + 1 | |||
| assert workflow_run.status == WorkflowRunStatus.RUNNING | |||
| assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT | |||
| assert workflow_run.created_by == "test-user-id" | |||
| # Verify session.add was called | |||
| mock_session.add.assert_called_once_with(workflow_run) | |||
| def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| """Test _handle_workflow_run_success method""" | |||
| # Mock _get_workflow_run to return the mock_workflow_run | |||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_run_success( | |||
| session=mock_session, | |||
| workflow_run_id="test-workflow-run-id", | |||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||
| total_tokens=100, | |||
| total_steps=5, | |||
| outputs={"answer": "test answer"}, | |||
| ) | |||
| # Verify the result | |||
| assert result == mock_workflow_run | |||
| assert result.status == WorkflowRunStatus.SUCCEEDED | |||
| assert result.outputs == json.dumps({"answer": "test answer"}) | |||
| assert result.total_tokens == 100 | |||
| assert result.total_steps == 5 | |||
| assert result.finished_at is not None | |||
| def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| """Test _handle_workflow_run_failed method""" | |||
| # Mock _get_workflow_run to return the mock_workflow_run | |||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||
| # Mock get_running_executions to return an empty list | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=mock_session, | |||
| workflow_run_id="test-workflow-run-id", | |||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||
| total_tokens=50, | |||
| total_steps=3, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error="Test error message", | |||
| ) | |||
| # Verify the result | |||
| assert result == mock_workflow_run | |||
| assert result.status == WorkflowRunStatus.FAILED.value | |||
| assert result.error == "Test error message" | |||
| assert result.total_tokens == 50 | |||
| assert result.total_steps == 3 | |||
| assert result.finished_at is not None | |||
| def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): | |||
| """Test _handle_node_execution_start method""" | |||
| assert workflow_execution.workflow_id == real_workflow.id | |||
| assert workflow_execution.sequence_number == 6 # max_sequence + 1 | |||
| # Verify the workflow_execution_repository.save was called | |||
| workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution) | |||
| def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository): | |||
| """Test handle_workflow_run_success method""" | |||
| # Create a real WorkflowExecution | |||
| workflow_execution = WorkflowExecution( | |||
| id="test-workflow-run-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_version="1.0", | |||
| sequence_number=1, | |||
| type=WorkflowType.CHAT, | |||
| graph={"nodes": [], "edges": []}, | |||
| inputs={"query": "test query"}, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution | |||
| # Call the method | |||
| result = workflow_cycle_manager.handle_workflow_run_success( | |||
| workflow_run_id="test-workflow-run-id", | |||
| total_tokens=100, | |||
| total_steps=5, | |||
| outputs={"answer": "test answer"}, | |||
| ) | |||
| # Verify the result | |||
| assert result == workflow_execution | |||
| assert result.status == WorkflowExecutionStatus.SUCCEEDED | |||
| assert result.outputs == {"answer": "test answer"} | |||
| assert result.total_tokens == 100 | |||
| assert result.total_steps == 5 | |||
| assert result.finished_at is not None | |||
| def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository): | |||
| """Test handle_workflow_run_failed method""" | |||
| # Create a real WorkflowExecution | |||
| workflow_execution = WorkflowExecution( | |||
| id="test-workflow-run-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_version="1.0", | |||
| sequence_number=1, | |||
| type=WorkflowType.CHAT, | |||
| graph={"nodes": [], "edges": []}, | |||
| inputs={"query": "test query"}, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution | |||
| # Mock get_running_executions to return an empty list | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] | |||
| # Call the method | |||
| result = workflow_cycle_manager.handle_workflow_run_failed( | |||
| workflow_run_id="test-workflow-run-id", | |||
| total_tokens=50, | |||
| total_steps=3, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error_message="Test error message", | |||
| ) | |||
| # Verify the result | |||
| assert result == workflow_execution | |||
| assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value) | |||
| assert result.error_message == "Test error message" | |||
| assert result.total_tokens == 50 | |||
| assert result.total_steps == 3 | |||
| assert result.finished_at is not None | |||
| def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository): | |||
| """Test handle_node_execution_start method""" | |||
| # Create a real WorkflowExecution | |||
| workflow_execution = WorkflowExecution( | |||
| id="test-workflow-execution-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_version="1.0", | |||
| sequence_number=1, | |||
| type=WorkflowType.CHAT, | |||
| graph={"nodes": [], "edges": []}, | |||
| inputs={"query": "test query"}, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution | |||
| # Create a mock event | |||
| event = MagicMock(spec=QueueNodeStartedEvent) | |||
| event.node_execution_id = "test-node-execution-id" | |||
| @@ -207,129 +300,171 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): | |||
| event.in_loop_id = "test-loop-id" | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_node_execution_start( | |||
| workflow_run=mock_workflow_run, | |||
| result = workflow_cycle_manager.handle_node_execution_start( | |||
| workflow_execution_id=workflow_execution.id, | |||
| event=event, | |||
| ) | |||
| # Verify the result | |||
| # NodeExecution doesn't have tenant_id attribute, it's handled at repository level | |||
| # assert result.tenant_id == mock_workflow_run.tenant_id | |||
| # assert result.app_id == mock_workflow_run.app_id | |||
| assert result.workflow_id == mock_workflow_run.workflow_id | |||
| assert result.workflow_run_id == mock_workflow_run.id | |||
| assert result.workflow_id == workflow_execution.workflow_id | |||
| assert result.workflow_run_id == workflow_execution.id | |||
| assert result.node_execution_id == event.node_execution_id | |||
| assert result.node_id == event.node_id | |||
| assert result.node_type == event.node_type | |||
| assert result.title == event.node_data.title | |||
| assert result.status == WorkflowNodeExecutionStatus.RUNNING.value | |||
| # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level | |||
| # assert result.created_by_role == mock_workflow_run.created_by_role | |||
| # assert result.created_by == mock_workflow_run.created_by | |||
| assert result.status == NodeExecutionStatus.RUNNING | |||
| # Verify save was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) | |||
| def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| """Test _get_workflow_run method""" | |||
| # Mock session.scalar to return the workflow run | |||
| mock_session.scalar.return_value = mock_workflow_run | |||
| def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository): | |||
| """Test _get_workflow_execution_or_raise_error method""" | |||
| # Create a real WorkflowExecution | |||
| # Call the method | |||
| result = workflow_cycle_manager._get_workflow_run( | |||
| session=mock_session, | |||
| workflow_run_id="test-workflow-run-id", | |||
| workflow_execution = WorkflowExecution( | |||
| id="test-workflow-run-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_version="1.0", | |||
| sequence_number=1, | |||
| type=WorkflowType.CHAT, | |||
| graph={"nodes": [], "edges": []}, | |||
| inputs={"query": "test query"}, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock the repository get method to return the real execution | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution | |||
| # Call the method | |||
| result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") | |||
| # Verify the result | |||
| assert result == mock_workflow_run | |||
| assert workflow_cycle_manager._workflow_run == mock_workflow_run | |||
| assert result == workflow_execution | |||
| # Test error case | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = None | |||
| # Expect an error when execution is not found | |||
| with pytest.raises(ValueError): | |||
| workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id") | |||
| def test_handle_workflow_node_execution_success(workflow_cycle_manager): | |||
| """Test _handle_workflow_node_execution_success method""" | |||
| """Test handle_workflow_node_execution_success method""" | |||
| # Create a mock event | |||
| event = MagicMock(spec=QueueNodeSucceededEvent) | |||
| event.node_execution_id = "test-node-execution-id" | |||
| event.inputs = {"input": "test input"} | |||
| event.process_data = {"process": "test process"} | |||
| event.outputs = {"output": "test output"} | |||
| event.execution_metadata = {"metadata": "test metadata"} | |||
| event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} | |||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| # Create a mock node execution | |||
| node_execution = MagicMock() | |||
| node_execution.node_execution_id = "test-node-execution-id" | |||
| # Create a real node execution | |||
| node_execution = NodeExecution( | |||
| id="test-node-execution-record-id", | |||
| node_execution_id="test-node-execution-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_run_id="test-workflow-run-id", | |||
| index=1, | |||
| node_id="test-node-id", | |||
| node_type=NodeType.LLM, | |||
| title="Test Node", | |||
| created_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock the repository to return the node execution | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| result = workflow_cycle_manager.handle_workflow_node_execution_success( | |||
| event=event, | |||
| ) | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| assert result.status == NodeExecutionStatus.SUCCEEDED | |||
| # Verify save was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) | |||
| def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| """Test _handle_workflow_run_partial_success method""" | |||
| # Mock _get_workflow_run to return the mock_workflow_run | |||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_run_partial_success( | |||
| session=mock_session, | |||
| workflow_run_id="test-workflow-run-id", | |||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||
| total_tokens=75, | |||
| total_steps=4, | |||
| outputs={"partial_answer": "test partial answer"}, | |||
| exceptions_count=2, | |||
| ) | |||
| # Verify the result | |||
| assert result == mock_workflow_run | |||
| assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value | |||
| assert result.outputs == json.dumps({"partial_answer": "test partial answer"}) | |||
| assert result.total_tokens == 75 | |||
| assert result.total_steps == 4 | |||
| assert result.exceptions_count == 2 | |||
| assert result.finished_at is not None | |||
| def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository): | |||
| """Test handle_workflow_run_partial_success method""" | |||
| # Create a real WorkflowExecution | |||
| workflow_execution = WorkflowExecution( | |||
| id="test-workflow-run-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_version="1.0", | |||
| sequence_number=1, | |||
| type=WorkflowType.CHAT, | |||
| graph={"nodes": [], "edges": []}, | |||
| inputs={"query": "test query"}, | |||
| started_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution | |||
| workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution | |||
| # Call the method | |||
| result = workflow_cycle_manager.handle_workflow_run_partial_success( | |||
| workflow_run_id="test-workflow-run-id", | |||
| total_tokens=75, | |||
| total_steps=4, | |||
| outputs={"partial_answer": "test partial answer"}, | |||
| exceptions_count=2, | |||
| ) | |||
| # Verify the result | |||
| assert result == workflow_execution | |||
| assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED | |||
| assert result.outputs == {"partial_answer": "test partial answer"} | |||
| assert result.total_tokens == 75 | |||
| assert result.total_steps == 4 | |||
| assert result.exceptions_count == 2 | |||
| assert result.finished_at is not None | |||
| def test_handle_workflow_node_execution_failed(workflow_cycle_manager): | |||
| """Test _handle_workflow_node_execution_failed method""" | |||
| """Test handle_workflow_node_execution_failed method""" | |||
| # Create a mock event | |||
| event = MagicMock(spec=QueueNodeFailedEvent) | |||
| event.node_execution_id = "test-node-execution-id" | |||
| event.inputs = {"input": "test input"} | |||
| event.process_data = {"process": "test process"} | |||
| event.outputs = {"output": "test output"} | |||
| event.execution_metadata = {"metadata": "test metadata"} | |||
| event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} | |||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| event.error = "Test error message" | |||
| # Create a mock node execution | |||
| node_execution = MagicMock() | |||
| node_execution.node_execution_id = "test-node-execution-id" | |||
| # Create a real node execution | |||
| node_execution = NodeExecution( | |||
| id="test-node-execution-record-id", | |||
| node_execution_id="test-node-execution-id", | |||
| workflow_id="test-workflow-id", | |||
| workflow_run_id="test-workflow-run-id", | |||
| index=1, | |||
| node_id="test-node-id", | |||
| node_type=NodeType.LLM, | |||
| title="Test Node", | |||
| created_at=datetime.now(UTC).replace(tzinfo=None), | |||
| ) | |||
| # Mock the repository to return the node execution | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| result = workflow_cycle_manager.handle_workflow_node_execution_failed( | |||
| event=event, | |||
| ) | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||
| assert result.status == NodeExecutionStatus.FAILED | |||
| assert result.error == "Test error message" | |||
| # Verify save was called | |||