Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>tags/0.8.0
| @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): | |||
| CURRENT_VERSION: str = Field( | |||
| description="Dify version", | |||
| default="0.7.3", | |||
| default="0.8.0", | |||
| ) | |||
| COMMIT_SHA: str = Field( | |||
| @@ -4,12 +4,10 @@ import os | |||
| import threading | |||
| import uuid | |||
| from collections.abc import Generator | |||
| from typing import Literal, Union, overload | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| import contexts | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| @@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import ConversationVariable, Workflow | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| ) -> dict: ... | |||
| def generate( | |||
| self, app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: dict, | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| ): | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: dict, | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| ) -> dict[str, Any] | Generator[str, Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| node_id: str, | |||
| user: Account, | |||
| args: dict, | |||
| stream: bool = True): | |||
| stream: bool = True) \ | |||
| -> dict[str, Any] | Generator[str, Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| if args.get('inputs') is None: | |||
| raise ValueError('inputs is required') | |||
| extras = { | |||
| "auto_generate_conversation_name": False | |||
| } | |||
| # get conversation | |||
| conversation = None | |||
| conversation_id = args.get('conversation_id') | |||
| if conversation_id: | |||
| conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) | |||
| # convert to app config | |||
| app_config = AdvancedChatAppConfigManager.get_app_config( | |||
| app_model=app_model, | |||
| @@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| application_generate_entity = AdvancedChatAppGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| conversation_id=conversation.id if conversation else None, | |||
| conversation_id=None, | |||
| inputs={}, | |||
| query='', | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=stream, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras=extras, | |||
| extras={ | |||
| "auto_generate_conversation_name": False | |||
| }, | |||
| single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( | |||
| node_id=node_id, | |||
| inputs=args['inputs'] | |||
| @@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| conversation=conversation, | |||
| conversation=None, | |||
| stream=stream | |||
| ) | |||
| def _generate(self, *, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| invoke_from: InvokeFrom, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| conversation: Conversation | None = None, | |||
| stream: bool = True): | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| invoke_from: InvokeFrom, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| conversation: Optional[Conversation] = None, | |||
| stream: bool = True) \ | |||
| -> dict[str, Any] | Generator[str, Any, None]: | |||
| """ | |||
| Generate App response. | |||
| :param workflow: Workflow | |||
| :param user: account or end user | |||
| :param invoke_from: invoke from source | |||
| :param application_generate_entity: application generate entity | |||
| :param conversation: conversation | |||
| :param stream: is stream | |||
| """ | |||
| is_first_conversation = False | |||
| if not conversation: | |||
| is_first_conversation = True | |||
| @@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # update conversation features | |||
| conversation.override_model_configs = workflow.features | |||
| db.session.commit() | |||
| # db.session.refresh(conversation) | |||
| db.session.refresh(conversation) | |||
| # init queue manager | |||
| queue_manager = MessageBasedAppQueueManager( | |||
| @@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id | |||
| ) | |||
| # Init conversation variables | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id | |||
| ) | |||
| with Session(db.engine) as session: | |||
| conversation_variables = session.scalars(stmt).all() | |||
| if not conversation_variables: | |||
| # Create conversation variables if they don't exist. | |||
| conversation_variables = [ | |||
| ConversationVariable.from_variable( | |||
| app_id=conversation.app_id, conversation_id=conversation.id, variable=variable | |||
| ) | |||
| for variable in workflow.conversation_variables | |||
| ] | |||
| session.add_all(conversation_variables) | |||
| # Convert database entities to variables. | |||
| conversation_variables = [item.to_variable() for item in conversation_variables] | |||
| session.commit() | |||
| # Increment dialogue count. | |||
| conversation.dialogue_count += 1 | |||
| conversation_id = conversation.id | |||
| conversation_dialogue_count = conversation.dialogue_count | |||
| db.session.commit() | |||
| db.session.refresh(conversation) | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| files = application_generate_entity.files | |||
| user_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = application_generate_entity.user_id | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.QUERY: query, | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation_id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, | |||
| } | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| contexts.workflow_variable_pool.set(variable_pool) | |||
| # new thread | |||
| worker_thread = threading.Thread(target=self._generate_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'flask_app': current_app._get_current_object(), # type: ignore | |||
| 'application_generate_entity': application_generate_entity, | |||
| 'queue_manager': queue_manager, | |||
| 'conversation_id': conversation.id, | |||
| 'message_id': message.id, | |||
| 'context': contextvars.copy_context(), | |||
| }) | |||
| @@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| def _generate_worker(self, flask_app: Flask, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| context: contextvars.Context) -> None: | |||
| """ | |||
| @@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| var.set(val) | |||
| with flask_app.app_context(): | |||
| try: | |||
| runner = AdvancedChatAppRunner() | |||
| if application_generate_entity.single_iteration_run: | |||
| single_iteration_run = application_generate_entity.single_iteration_run | |||
| runner.single_iteration_run( | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| workflow_id=application_generate_entity.app_config.workflow_id, | |||
| queue_manager=queue_manager, | |||
| inputs=single_iteration_run.inputs, | |||
| node_id=single_iteration_run.node_id, | |||
| user_id=application_generate_entity.user_id | |||
| ) | |||
| else: | |||
| # get message | |||
| message = self._get_message(message_id) | |||
| # chatbot app | |||
| runner = AdvancedChatAppRunner() | |||
| runner.run( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| message=message | |||
| ) | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| # chatbot app | |||
| runner = AdvancedChatAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedException: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| @@ -1,49 +1,67 @@ | |||
| import logging | |||
| import os | |||
| import time | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAnnotationReplyEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| ) | |||
| from core.moderation.base import ModerationException | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models import App, Message, Workflow | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import ConversationVariable, WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| class AdvancedChatAppRunner(AppRunner): | |||
| class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| """ | |||
| AdvancedChat Application Runner | |||
| """ | |||
| def run( | |||
| self, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| message: Message, | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message | |||
| ) -> None: | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| """ | |||
| super().__init__(queue_manager) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| self.message = message | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| :return: | |||
| """ | |||
| app_config = application_generate_entity.app_config | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| @@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner): | |||
| if not workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = self.application_generate_entity.user_id | |||
| # moderation | |||
| if self.handle_input_moderation( | |||
| queue_manager=queue_manager, | |||
| app_record=app_record, | |||
| app_generate_entity=application_generate_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| message_id=message.id, | |||
| ): | |||
| return | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # annotation reply | |||
| if self.handle_annotation_reply( | |||
| app_record=app_record, | |||
| message=message, | |||
| query=query, | |||
| queue_manager=queue_manager, | |||
| app_generate_entity=application_generate_entity, | |||
| ): | |||
| return | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| query = self.application_generate_entity.query | |||
| files = self.application_generate_entity.files | |||
| db.session.close() | |||
| # moderation | |||
| if self.handle_input_moderation( | |||
| app_record=app_record, | |||
| app_generate_entity=self.application_generate_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| message_id=self.message.id | |||
| ): | |||
| return | |||
| workflow_callbacks: list[WorkflowCallback] = [ | |||
| WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) | |||
| ] | |||
| # annotation reply | |||
| if self.handle_annotation_reply( | |||
| app_record=app_record, | |||
| message=self.message, | |||
| query=query, | |||
| app_generate_entity=self.application_generate_entity | |||
| ): | |||
| return | |||
| if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # Init conversation variables | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id | |||
| ) | |||
| with Session(db.engine) as session: | |||
| conversation_variables = session.scalars(stmt).all() | |||
| if not conversation_variables: | |||
| # Create conversation variables if they don't exist. | |||
| conversation_variables = [ | |||
| ConversationVariable.from_variable( | |||
| app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable | |||
| ) | |||
| for variable in workflow.conversation_variables | |||
| ] | |||
| session.add_all(conversation_variables) | |||
| # Convert database entities to variables. | |||
| conversation_variables = [item.to_variable() for item in conversation_variables] | |||
| # RUN WORKFLOW | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| workflow_engine_manager.run_workflow( | |||
| workflow=workflow, | |||
| user_id=application_generate_entity.user_id, | |||
| user_from=UserFrom.ACCOUNT | |||
| if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| else UserFrom.END_USER, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| callbacks=workflow_callbacks, | |||
| call_depth=application_generate_entity.call_depth, | |||
| ) | |||
| session.commit() | |||
| def single_iteration_run( | |||
| self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str | |||
| ) -> None: | |||
| """ | |||
| Single iteration run | |||
| """ | |||
| app_record = db.session.query(App).filter(App.id == app_id).first() | |||
| if not app_record: | |||
| raise ValueError('App not found') | |||
| # Increment dialogue count. | |||
| self.conversation.dialogue_count += 1 | |||
| workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) | |||
| if not workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| conversation_dialogue_count = self.conversation.dialogue_count | |||
| db.session.commit() | |||
| workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.QUERY: query, | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.CONVERSATION_ID: self.conversation.id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, | |||
| } | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| workflow_engine_manager.single_step_run_iteration_workflow_node( | |||
| workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=workflow.graph_dict) | |||
| db.session.close() | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| generator = workflow_entry.run( | |||
| callbacks=workflow_callbacks, | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| for event in generator: | |||
| self._handle_event(workflow_entry, event) | |||
| def handle_input_moderation( | |||
| self, | |||
| queue_manager: AppQueueManager, | |||
| app_record: App, | |||
| app_generate_entity: AdvancedChatAppGenerateEntity, | |||
| inputs: Mapping[str, Any], | |||
| query: str, | |||
| message_id: str, | |||
| self, | |||
| app_record: App, | |||
| app_generate_entity: AdvancedChatAppGenerateEntity, | |||
| inputs: Mapping[str, Any], | |||
| query: str, | |||
| message_id: str | |||
| ) -> bool: | |||
| """ | |||
| Handle input moderation | |||
| :param queue_manager: application queue manager | |||
| :param app_record: app record | |||
| :param app_generate_entity: application generate entity | |||
| :param inputs: inputs | |||
| @@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner): | |||
| message_id=message_id, | |||
| ) | |||
| except ModerationException as e: | |||
| self._stream_output( | |||
| queue_manager=queue_manager, | |||
| self._complete_with_stream_output( | |||
| text=str(e), | |||
| stream=app_generate_entity.stream, | |||
| stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, | |||
| stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION | |||
| ) | |||
| return True | |||
| return False | |||
| def handle_annotation_reply( | |||
| self, | |||
| app_record: App, | |||
| message: Message, | |||
| query: str, | |||
| queue_manager: AppQueueManager, | |||
| app_generate_entity: AdvancedChatAppGenerateEntity, | |||
| ) -> bool: | |||
| def handle_annotation_reply(self, app_record: App, | |||
| message: Message, | |||
| query: str, | |||
| app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: | |||
| """ | |||
| Handle annotation reply | |||
| :param app_record: app record | |||
| :param message: message | |||
| :param query: query | |||
| :param queue_manager: application queue manager | |||
| :param app_generate_entity: application generate entity | |||
| """ | |||
| # annotation reply | |||
| @@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner): | |||
| ) | |||
| if annotation_reply: | |||
| queue_manager.publish( | |||
| QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER | |||
| self._publish_event( | |||
| QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) | |||
| ) | |||
| self._stream_output( | |||
| queue_manager=queue_manager, | |||
| self._complete_with_stream_output( | |||
| text=annotation_reply.content, | |||
| stream=app_generate_entity.stream, | |||
| stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, | |||
| stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY | |||
| ) | |||
| return True | |||
| return False | |||
| def _stream_output( | |||
| self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy | |||
| ) -> None: | |||
| def _complete_with_stream_output(self, | |||
| text: str, | |||
| stopped_by: QueueStopEvent.StopBy) -> None: | |||
| """ | |||
| Direct output | |||
| :param queue_manager: application queue manager | |||
| :param text: text | |||
| :param stream: stream | |||
| :return: | |||
| """ | |||
| if stream: | |||
| index = 0 | |||
| for token in text: | |||
| queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) | |||
| index += 1 | |||
| time.sleep(0.01) | |||
| else: | |||
| queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) | |||
| self._publish_event( | |||
| QueueTextChunkEvent( | |||
| text=text | |||
| ) | |||
| ) | |||
| queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) | |||
| self._publish_event( | |||
| QueueStopEvent(stopped_by=stopped_by) | |||
| ) | |||
| @@ -2,9 +2,8 @@ import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, Union, cast | |||
| from typing import Any, Optional, Union | |||
| import contexts | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -22,6 +21,9 @@ from core.app.entities.queue_entities import ( | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueuePingEvent, | |||
| QueueRetrieverResourcesEvent, | |||
| QueueStopEvent, | |||
| @@ -31,34 +33,28 @@ from core.app.entities.queue_entities import ( | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AdvancedChatTaskState, | |||
| ChatbotAppBlockingResponse, | |||
| ChatbotAppStreamResponse, | |||
| ChatflowStreamGenerateRoute, | |||
| ErrorStreamResponse, | |||
| MessageAudioEndStreamResponse, | |||
| MessageAudioStreamResponse, | |||
| MessageEndStreamResponse, | |||
| StreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||
| from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage | |||
| from core.file.file_obj import FileVar | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import Conversation, EndUser, Message | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecution, | |||
| WorkflowRunStatus, | |||
| ) | |||
| @@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| """ | |||
| AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. | |||
| """ | |||
| _task_state: AdvancedChatTaskState | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: AdvancedChatAppGenerateEntity | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| # Deprecated | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _iteration_nested_relations: dict[str, list[str]] | |||
| def __init__( | |||
| self, application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| self, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| @@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._workflow = workflow | |||
| self._conversation = conversation | |||
| self._message = message | |||
| # Deprecated | |||
| self._workflow_system_variables = { | |||
| SystemVariableKey.QUERY: message.query, | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| @@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| SystemVariableKey.USER_ID: user_id, | |||
| } | |||
| self._task_state = AdvancedChatTaskState( | |||
| usage=LLMUsage.empty_usage() | |||
| ) | |||
| self._task_state = WorkflowTaskState() | |||
| self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) | |||
| self._stream_generate_routes = self._get_stream_generate_routes() | |||
| self._conversation_name_generate_thread = None | |||
| def process(self): | |||
| @@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| generator = self._wrapper_process_stream_response( | |||
| trace_manager=self._application_generate_entity.trace_manager | |||
| ) | |||
| if self._stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| @@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ | |||
| Generator[StreamResponse, None, None]: | |||
| publisher = None | |||
| tts_publisher = None | |||
| task_id = self._application_generate_entity.task_id | |||
| tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| features_dict = self._workflow.features_dict | |||
| if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ | |||
| 'text_to_speech'].get('autoPlay') == 'enabled': | |||
| publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) | |||
| for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | |||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listenAudioMsg(publisher, task_id=task_id) | |||
| audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| @@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| # timeout | |||
| while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: | |||
| try: | |||
| if not publisher: | |||
| if not tts_publisher: | |||
| break | |||
| audio_trunk = publisher.checkAndGetAudio() | |||
| audio_trunk = tts_publisher.checkAndGetAudio() | |||
| if audio_trunk is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| @@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| def _process_stream_response( | |||
| self, | |||
| publisher: AppGeneratorTTSPublisher, | |||
| tts_publisher: Optional[AppGeneratorTTSPublisher] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """ | |||
| Process stream response. | |||
| :return: | |||
| """ | |||
| for message in self._queue_manager.listen(): | |||
| if (message.event | |||
| and getattr(message.event, 'metadata', None) | |||
| and message.event.metadata.get('is_answer_previous_node', False) | |||
| and publisher): | |||
| publisher.publish(message=message) | |||
| elif (hasattr(message.event, 'execution_metadata') | |||
| and message.event.execution_metadata | |||
| and message.event.execution_metadata.get('is_answer_previous_node', False) | |||
| and publisher): | |||
| publisher.publish(message=message) | |||
| event = message.event | |||
| if isinstance(event, QueueErrorEvent): | |||
| # init fake graph runtime state | |||
| graph_runtime_state = None | |||
| workflow_run = None | |||
| for queue_message in self._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event, self._message) | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| workflow_run = self._handle_workflow_start() | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start() | |||
| self._message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| self._refetch_message() | |||
| self._message.workflow_run_id = workflow_run.id | |||
| db.session.commit() | |||
| @@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueNodeStartedEvent): | |||
| workflow_node_execution = self._handle_node_start(event) | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| # search stream_generate_routes if node id is answer start at node | |||
| if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: | |||
| self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] | |||
| # reset current route position to 0 | |||
| self._task_state.current_stream_generate_state.current_route_position = 0 | |||
| workflow_node_execution = self._handle_node_execution_start( | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| # generate stream outputs when node started | |||
| yield from self._generate_stream_outputs_when_node_started() | |||
| response = self._workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| yield self._workflow_node_start_to_stream_response( | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): | |||
| workflow_node_execution = self._handle_node_finished(event) | |||
| # stream outputs when node finished | |||
| generator = self._generate_stream_outputs_when_node_finished() | |||
| if generator: | |||
| yield from generator | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| yield self._workflow_node_finish_to_stream_response( | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| if isinstance(event, QueueNodeFailedEvent): | |||
| yield from self._handle_iteration_exception( | |||
| task_id=self._application_generate_entity.task_id, | |||
| error=f'Child node failed: {event.error}' | |||
| ) | |||
| elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): | |||
| if isinstance(event, QueueIterationNextEvent): | |||
| # clear ran node execution infos of current iteration | |||
| iteration_relations = self._iteration_nested_relations.get(event.node_id) | |||
| if iteration_relations: | |||
| for node_id in iteration_relations: | |||
| self._task_state.ran_node_execution_infos.pop(node_id, None) | |||
| yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) | |||
| self._handle_iteration_operation(event) | |||
| elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): | |||
| workflow_run = self._handle_workflow_finished( | |||
| event, conversation_id=self._conversation.id, trace_manager=trace_manager | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| if workflow_run: | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if workflow_run.status == WorkflowRunStatus.FAILED.value: | |||
| err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) | |||
| yield self._error_to_stream_response(self._handle_error(err_event, self._message)) | |||
| break | |||
| yield self._workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationStartEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if isinstance(event, QueueStopEvent): | |||
| # Save message | |||
| self._save_message() | |||
| yield self._workflow_iteration_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| else: | |||
| self._queue_manager.publish( | |||
| QueueAdvancedChatMessageEndEvent(), | |||
| PublishFrom.TASK_PIPELINE | |||
| yield self._workflow_iteration_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_iteration_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueWorkflowSucceededEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if not graph_runtime_state: | |||
| raise Exception('Graph runtime state not initialized.') | |||
| workflow_run = self._handle_workflow_run_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=json.dumps(event.outputs) if event.outputs else None, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| self._queue_manager.publish( | |||
| QueueAdvancedChatMessageEndEvent(), | |||
| PublishFrom.TASK_PIPELINE | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if not graph_runtime_state: | |||
| raise Exception('Graph runtime state not initialized.') | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error=event.error, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) | |||
| yield self._error_to_stream_response(self._handle_error(err_event, self._message)) | |||
| break | |||
| elif isinstance(event, QueueStopEvent): | |||
| if workflow_run and graph_runtime_state: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| error=event.get_stop_reason(), | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | |||
| output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) | |||
| if output_moderation_answer: | |||
| self._task_state.answer = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| # Save message | |||
| self._save_message() | |||
| self._save_message(graph_runtime_state=graph_runtime_state) | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| self._refetch_message() | |||
| self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ | |||
| if self._task_state.metadata else None | |||
| db.session.commit() | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| self._handle_annotation_reply(event) | |||
| self._refetch_message() | |||
| self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ | |||
| if self._task_state.metadata else None | |||
| db.session.commit() | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| if delta_text is None: | |||
| continue | |||
| if not self._is_stream_out_support( | |||
| event=event | |||
| ): | |||
| continue | |||
| # handle output moderation chunk | |||
| should_direct_answer = self._handle_output_moderation_chunk(delta_text) | |||
| if should_direct_answer: | |||
| continue | |||
| # only publish tts message at text chunk streaming | |||
| if tts_publisher: | |||
| tts_publisher.publish(message=queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._message_to_stream_response(delta_text, self._message.id) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| # published by moderation | |||
| yield self._message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | |||
| if not graph_runtime_state: | |||
| raise Exception('Graph runtime state not initialized.') | |||
| output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) | |||
| if output_moderation_answer: | |||
| self._task_state.answer = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| # Save message | |||
| self._save_message(graph_runtime_state=graph_runtime_state) | |||
| yield self._message_end_to_stream_response() | |||
| else: | |||
| continue | |||
| if publisher: | |||
| publisher.publish(None) | |||
| # publish None when task finished | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self) -> None: | |||
| def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| """ | |||
| Save message. | |||
| :return: | |||
| """ | |||
| self._message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| self._refetch_message() | |||
| self._message.answer = self._task_state.answer | |||
| self._message.provider_response_latency = time.perf_counter() - self._start_at | |||
| self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ | |||
| if self._task_state.metadata else None | |||
| if self._task_state.metadata and self._task_state.metadata.get('usage'): | |||
| usage = LLMUsage(**self._task_state.metadata['usage']) | |||
| if graph_runtime_state and graph_runtime_state.llm_usage: | |||
| usage = graph_runtime_state.llm_usage | |||
| self._message.message_tokens = usage.prompt_tokens | |||
| self._message.message_unit_price = usage.prompt_unit_price | |||
| self._message.message_price_unit = usage.prompt_price_unit | |||
| @@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| """ | |||
| extras = {} | |||
| if self._task_state.metadata: | |||
| extras['metadata'] = self._task_state.metadata | |||
| extras['metadata'] = self._task_state.metadata.copy() | |||
| if 'annotation_reply' in extras['metadata']: | |||
| del extras['metadata']['annotation_reply'] | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| **extras | |||
| ) | |||
| def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]: | |||
| """ | |||
| Get stream generate routes. | |||
| :return: | |||
| """ | |||
| # find all answer nodes | |||
| graph = self._workflow.graph_dict | |||
| answer_node_configs = [ | |||
| node for node in graph['nodes'] | |||
| if node.get('data', {}).get('type') == NodeType.ANSWER.value | |||
| ] | |||
| # parse stream output node value selectors of answer nodes | |||
| stream_generate_routes = {} | |||
| for node_config in answer_node_configs: | |||
| # get generate route for stream output | |||
| answer_node_id = node_config['id'] | |||
| generate_route = AnswerNode.extract_generate_route_selectors(node_config) | |||
| start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) | |||
| if not start_node_ids: | |||
| continue | |||
| for start_node_id in start_node_ids: | |||
| stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute( | |||
| answer_node_id=answer_node_id, | |||
| generate_route=generate_route | |||
| ) | |||
| return stream_generate_routes | |||
| def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ | |||
| -> list[str]: | |||
| """ | |||
| Get answer start at node id. | |||
| :param graph: graph | |||
| :param target_node_id: target node ID | |||
| :return: | |||
| """ | |||
| nodes = graph.get('nodes') | |||
| edges = graph.get('edges') | |||
| # fetch all ingoing edges from source node | |||
| ingoing_edges = [] | |||
| for edge in edges: | |||
| if edge.get('target') == target_node_id: | |||
| ingoing_edges.append(edge) | |||
| if not ingoing_edges: | |||
| # check if it's the first node in the iteration | |||
| target_node = next((node for node in nodes if node.get('id') == target_node_id), None) | |||
| if not target_node: | |||
| return [] | |||
| node_iteration_id = target_node.get('data', {}).get('iteration_id') | |||
| # get iteration start node id | |||
| for node in nodes: | |||
| if node.get('id') == node_iteration_id: | |||
| if node.get('data', {}).get('start_node_id') == target_node_id: | |||
| return [target_node_id] | |||
| return [] | |||
| start_node_ids = [] | |||
| for ingoing_edge in ingoing_edges: | |||
| source_node_id = ingoing_edge.get('source') | |||
| source_node = next((node for node in nodes if node.get('id') == source_node_id), None) | |||
| if not source_node: | |||
| continue | |||
| node_type = source_node.get('data', {}).get('type') | |||
| node_iteration_id = source_node.get('data', {}).get('iteration_id') | |||
| iteration_start_node_id = None | |||
| if node_iteration_id: | |||
| iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) | |||
| iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') | |||
| if node_type in [ | |||
| NodeType.ANSWER.value, | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER.value, | |||
| NodeType.ITERATION.value, | |||
| NodeType.LOOP.value | |||
| ]: | |||
| start_node_id = target_node_id | |||
| start_node_ids.append(start_node_id) | |||
| elif node_type == NodeType.START.value or \ | |||
| node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): | |||
| start_node_id = source_node_id | |||
| start_node_ids.append(start_node_id) | |||
| else: | |||
| sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) | |||
| if sub_start_node_ids: | |||
| start_node_ids.extend(sub_start_node_ids) | |||
| return start_node_ids | |||
| def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: | |||
| """ | |||
| Get iteration nested relations. | |||
| :param graph: graph | |||
| :return: | |||
| """ | |||
| nodes = graph.get('nodes') | |||
| iteration_ids = [node.get('id') for node in nodes | |||
| if node.get('data', {}).get('type') in [ | |||
| NodeType.ITERATION.value, | |||
| NodeType.LOOP.value, | |||
| ]] | |||
| return { | |||
| iteration_id: [ | |||
| node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id | |||
| ] for iteration_id in iteration_ids | |||
| } | |||
| def _generate_stream_outputs_when_node_started(self) -> Generator: | |||
| """ | |||
| Generate stream outputs. | |||
| :return: | |||
| """ | |||
| if self._task_state.current_stream_generate_state: | |||
| route_chunks = self._task_state.current_stream_generate_state.generate_route[ | |||
| self._task_state.current_stream_generate_state.current_route_position: | |||
| ] | |||
| for route_chunk in route_chunks: | |||
| if route_chunk.type == 'text': | |||
| route_chunk = cast(TextGenerateRouteChunk, route_chunk) | |||
| # handle output moderation chunk | |||
| should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text) | |||
| if should_direct_answer: | |||
| continue | |||
| self._task_state.answer += route_chunk.text | |||
| yield self._message_to_stream_response(route_chunk.text, self._message.id) | |||
| else: | |||
| break | |||
| self._task_state.current_stream_generate_state.current_route_position += 1 | |||
| # all route chunks are generated | |||
| if self._task_state.current_stream_generate_state.current_route_position == len( | |||
| self._task_state.current_stream_generate_state.generate_route | |||
| ): | |||
| self._task_state.current_stream_generate_state = None | |||
| def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: | |||
| """ | |||
| Generate stream outputs. | |||
| :return: | |||
| """ | |||
| if not self._task_state.current_stream_generate_state: | |||
| return | |||
| route_chunks = self._task_state.current_stream_generate_state.generate_route[ | |||
| self._task_state.current_stream_generate_state.current_route_position:] | |||
| for route_chunk in route_chunks: | |||
| if route_chunk.type == 'text': | |||
| route_chunk = cast(TextGenerateRouteChunk, route_chunk) | |||
| self._task_state.answer += route_chunk.text | |||
| yield self._message_to_stream_response(route_chunk.text, self._message.id) | |||
| else: | |||
| value = None | |||
| route_chunk = cast(VarGenerateRouteChunk, route_chunk) | |||
| value_selector = route_chunk.value_selector | |||
| if not value_selector: | |||
| self._task_state.current_stream_generate_state.current_route_position += 1 | |||
| continue | |||
| route_chunk_node_id = value_selector[0] | |||
| if route_chunk_node_id == 'sys': | |||
| # system variable | |||
| value = contexts.workflow_variable_pool.get().get(value_selector) | |||
| if value: | |||
| value = value.text | |||
| elif route_chunk_node_id in self._iteration_nested_relations: | |||
| # it's a iteration variable | |||
| if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: | |||
| continue | |||
| iteration_state = self._iteration_state.current_iterations[route_chunk_node_id] | |||
| iterator = iteration_state.inputs | |||
| if not iterator: | |||
| continue | |||
| iterator_selector = iterator.get('iterator_selector', []) | |||
| if value_selector[1] == 'index': | |||
| value = iteration_state.current_index | |||
| elif value_selector[1] == 'item': | |||
| value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len( | |||
| iterator_selector | |||
| ) else None | |||
| else: | |||
| # check chunk node id is before current node id or equal to current node id | |||
| if route_chunk_node_id not in self._task_state.ran_node_execution_infos: | |||
| break | |||
| latest_node_execution_info = self._task_state.latest_node_execution_info | |||
| # get route chunk node execution info | |||
| route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] | |||
| if (route_chunk_node_execution_info.node_type == NodeType.LLM | |||
| and latest_node_execution_info.node_type == NodeType.LLM): | |||
| # only LLM support chunk stream output | |||
| self._task_state.current_stream_generate_state.current_route_position += 1 | |||
| continue | |||
| # get route chunk node execution | |||
| route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id | |||
| ).first() | |||
| outputs = route_chunk_node_execution.outputs_dict | |||
| # get value from outputs | |||
| value = None | |||
| for key in value_selector[1:]: | |||
| if not value: | |||
| value = outputs.get(key) if outputs else None | |||
| else: | |||
| value = value.get(key) | |||
| if value is not None: | |||
| text = '' | |||
| if isinstance(value, str | int | float): | |||
| text = str(value) | |||
| elif isinstance(value, FileVar): | |||
| # convert file to markdown | |||
| text = value.to_markdown() | |||
| elif isinstance(value, dict): | |||
| # handle files | |||
| file_vars = self._fetch_files_from_variable_value(value) | |||
| if file_vars: | |||
| file_var = file_vars[0] | |||
| try: | |||
| file_var_obj = FileVar(**file_var) | |||
| # convert file to markdown | |||
| text = file_var_obj.to_markdown() | |||
| except Exception as e: | |||
| logger.error(f'Error creating file var: {e}') | |||
| if not text: | |||
| # other types | |||
| text = json.dumps(value, ensure_ascii=False) | |||
| elif isinstance(value, list): | |||
| # handle files | |||
| file_vars = self._fetch_files_from_variable_value(value) | |||
| for file_var in file_vars: | |||
| try: | |||
| file_var_obj = FileVar(**file_var) | |||
| except Exception as e: | |||
| logger.error(f'Error creating file var: {e}') | |||
| continue | |||
| # convert file to markdown | |||
| text = file_var_obj.to_markdown() + ' ' | |||
| text = text.strip() | |||
| if not text and value: | |||
| # other types | |||
| text = json.dumps(value, ensure_ascii=False) | |||
| if text: | |||
| self._task_state.answer += text | |||
| yield self._message_to_stream_response(text, self._message.id) | |||
| self._task_state.current_stream_generate_state.current_route_position += 1 | |||
| # all route chunks are generated | |||
| if self._task_state.current_stream_generate_state.current_route_position == len( | |||
| self._task_state.current_stream_generate_state.generate_route | |||
| ): | |||
| self._task_state.current_stream_generate_state = None | |||
| def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: | |||
| """ | |||
| Is stream out support | |||
| :param event: queue text chunk event | |||
| :return: | |||
| """ | |||
| if not event.metadata: | |||
| return True | |||
| if 'node_id' not in event.metadata: | |||
| return True | |||
| node_type = event.metadata.get('node_type') | |||
| stream_output_value_selector = event.metadata.get('value_selector') | |||
| if not stream_output_value_selector: | |||
| return False | |||
| if not self._task_state.current_stream_generate_state: | |||
| return False | |||
| route_chunk = self._task_state.current_stream_generate_state.generate_route[ | |||
| self._task_state.current_stream_generate_state.current_route_position] | |||
| if route_chunk.type != 'var': | |||
| return False | |||
| if node_type != NodeType.LLM: | |||
| # only LLM support chunk stream output | |||
| return False | |||
| route_chunk = cast(VarGenerateRouteChunk, route_chunk) | |||
| value_selector = route_chunk.value_selector | |||
| # check chunk node id is before current node id or equal to current node id | |||
| if value_selector != stream_output_value_selector: | |||
| return False | |||
| return True | |||
| def _handle_output_moderation_chunk(self, text: str) -> bool: | |||
| """ | |||
| Handle output moderation chunk. | |||
| @@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._output_moderation_handler.append_new_token(text) | |||
| return False | |||
| def _refetch_message(self) -> None: | |||
| """ | |||
| Refetch message. | |||
| :return: | |||
| """ | |||
| message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| if message: | |||
| self._message = message | |||
| @@ -1,203 +0,0 @@ | |||
| from typing import Any, Optional | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from models.workflow import Workflow | |||
| class WorkflowEventTriggerCallback(WorkflowCallback): | |||
| def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): | |||
| self._queue_manager = queue_manager | |||
| def on_workflow_run_started(self) -> None: | |||
| """ | |||
| Workflow run started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowStartedEvent(), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_run_succeeded(self) -> None: | |||
| """ | |||
| Workflow run succeeded | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowSucceededEvent(), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_run_failed(self, error: str) -> None: | |||
| """ | |||
| Workflow run failed | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowFailedEvent( | |||
| error=error | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_started(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| node_run_index: int = 1, | |||
| predecessor_node_id: Optional[str] = None) -> None: | |||
| """ | |||
| Workflow node execute started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeStartedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| node_run_index=node_run_index, | |||
| predecessor_node_id=predecessor_node_id | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_succeeded(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute succeeded | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeSucceededEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| execution_metadata=execution_metadata | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_failed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| error: str, | |||
| inputs: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute failed | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeFailedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| outputs=outputs, | |||
| process_data=process_data, | |||
| error=error | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish text chunk | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueTextChunkEvent( | |||
| text=text, | |||
| metadata={ | |||
| "node_id": node_id, | |||
| **metadata | |||
| } | |||
| ), PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_started(self, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int = 1, | |||
| node_data: Optional[BaseNodeData] = None, | |||
| inputs: dict = None, | |||
| predecessor_node_id: Optional[str] = None, | |||
| metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish iteration started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueIterationStartEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_run_index=node_run_index, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| predecessor_node_id=predecessor_node_id, | |||
| metadata=metadata | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_next(self, node_id: str, | |||
| node_type: NodeType, | |||
| index: int, | |||
| node_run_index: int, | |||
| output: Optional[Any]) -> None: | |||
| """ | |||
| Publish iteration next | |||
| """ | |||
| self._queue_manager._publish( | |||
| QueueIterationNextEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| index=index, | |||
| node_run_index=node_run_index, | |||
| output=output | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_completed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int, | |||
| outputs: dict) -> None: | |||
| """ | |||
| Publish iteration completed | |||
| """ | |||
| self._queue_manager._publish( | |||
| QueueIterationCompletedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_run_index=node_run_index, | |||
| outputs=outputs | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_event(self, event: AppQueueEvent) -> None: | |||
| """ | |||
| Publish event | |||
| """ | |||
| self._queue_manager.publish( | |||
| event, | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): | |||
| def convert(cls, response: Union[ | |||
| AppBlockingResponse, | |||
| Generator[AppStreamResponse, Any, None] | |||
| ], invoke_from: InvokeFrom): | |||
| ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: | |||
| if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: | |||
| if isinstance(response, AppBlockingResponse): | |||
| return cls.convert_blocking_full_response(response) | |||
| @@ -1,6 +1,6 @@ | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import TYPE_CHECKING, Optional, Union | |||
| from collections.abc import Generator, Mapping | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -347,7 +347,7 @@ class AppRunner: | |||
| self, app_id: str, | |||
| tenant_id: str, | |||
| app_generate_entity: AppGenerateEntity, | |||
| inputs: dict, | |||
| inputs: Mapping[str, Any], | |||
| query: str, | |||
| message_id: str, | |||
| ) -> tuple[bool, dict, str]: | |||
| @@ -4,7 +4,7 @@ import os | |||
| import threading | |||
| import uuid | |||
| from collections.abc import Generator | |||
| from typing import Literal, Union, overload | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| @@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| args: dict, | |||
| invoke_from: InvokeFrom, | |||
| stream: Literal[True] = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None | |||
| ) -> Generator[str, None, None]: ... | |||
| @overload | |||
| @@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| args: dict, | |||
| invoke_from: InvokeFrom, | |||
| stream: Literal[False] = False, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None | |||
| ) -> dict: ... | |||
| def generate( | |||
| self, app_model: App, | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: dict, | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None | |||
| ): | |||
| """ | |||
| Generate App response. | |||
| @@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| :param invoke_from: invoke from source | |||
| :param stream: is stream | |||
| :param call_depth: call depth | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| inputs = args['inputs'] | |||
| @@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| application_generate_entity=application_generate_entity, | |||
| invoke_from=invoke_from, | |||
| stream=stream, | |||
| workflow_thread_pool_id=workflow_thread_pool_id | |||
| ) | |||
| def _generate( | |||
| self, app_model: App, | |||
| self, *, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| ) -> Union[dict, Generator[str, None, None]]: | |||
| workflow_thread_pool_id: Optional[str] = None | |||
| ) -> dict[str, Any] | Generator[str, None, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| :param application_generate_entity: application generate entity | |||
| :param invoke_from: invoke from source | |||
| :param stream: is stream | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| # init queue manager | |||
| queue_manager = WorkflowAppQueueManager( | |||
| @@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| # new thread | |||
| worker_thread = threading.Thread(target=self._generate_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'flask_app': current_app._get_current_object(), # type: ignore | |||
| 'application_generate_entity': application_generate_entity, | |||
| 'queue_manager': queue_manager, | |||
| 'context': contextvars.copy_context() | |||
| 'context': contextvars.copy_context(), | |||
| 'workflow_thread_pool_id': workflow_thread_pool_id | |||
| }) | |||
| worker_thread.start() | |||
| @@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| node_id: str, | |||
| user: Account, | |||
| args: dict, | |||
| stream: bool = True): | |||
| stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| if args.get('inputs') is None: | |||
| raise ValueError('inputs is required') | |||
| extras = { | |||
| "auto_generate_conversation_name": False | |||
| } | |||
| # convert to app config | |||
| app_config = WorkflowAppConfigManager.get_app_config( | |||
| app_model=app_model, | |||
| @@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| user_id=user.id, | |||
| stream=stream, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras=extras, | |||
| extras={ | |||
| "auto_generate_conversation_name": False | |||
| }, | |||
| single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | |||
| node_id=node_id, | |||
| inputs=args['inputs'] | |||
| @@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| def _generate_worker(self, flask_app: Flask, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| context: contextvars.Context) -> None: | |||
| context: contextvars.Context, | |||
| workflow_thread_pool_id: Optional[str] = None) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| :return: | |||
| """ | |||
| for var, val in context.items(): | |||
| @@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| with flask_app.app_context(): | |||
| try: | |||
| # workflow app | |||
| runner = WorkflowAppRunner() | |||
| if application_generate_entity.single_iteration_run: | |||
| single_iteration_run = application_generate_entity.single_iteration_run | |||
| runner.single_iteration_run( | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| workflow_id=application_generate_entity.app_config.workflow_id, | |||
| queue_manager=queue_manager, | |||
| inputs=single_iteration_run.inputs, | |||
| node_id=single_iteration_run.node_id, | |||
| user_id=application_generate_entity.user_id | |||
| ) | |||
| else: | |||
| runner.run( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager | |||
| ) | |||
| runner = WorkflowAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedException: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| @@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': | |||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| finally: | |||
| db.session.remove() | |||
| db.session.close() | |||
| def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, | |||
| workflow: Workflow, | |||
| @@ -4,46 +4,61 @@ from typing import Optional, cast | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfig | |||
| from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback | |||
| from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.model import App, EndUser | |||
| from models.workflow import Workflow | |||
| from models.workflow import WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| class WorkflowAppRunner: | |||
| class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| """ | |||
| Workflow Application Runner | |||
| """ | |||
| def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| workflow_thread_pool_id: Optional[str] = None | |||
| ) -> None: | |||
| """ | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| self.application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :return: | |||
| """ | |||
| app_config = application_generate_entity.app_config | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(WorkflowAppConfig, app_config) | |||
| user_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = application_generate_entity.user_id | |||
| user_id = self.application_generate_entity.user_id | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| @@ -53,80 +68,64 @@ class WorkflowAppRunner: | |||
| if not workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| inputs = application_generate_entity.inputs | |||
| files = application_generate_entity.files | |||
| db.session.close() | |||
| workflow_callbacks: list[WorkflowCallback] = [ | |||
| WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) | |||
| ] | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| } | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=[], | |||
| ) | |||
| # if only single iteration run is requested | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs | |||
| ) | |||
| else: | |||
| # RUN WORKFLOW | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| workflow_engine_manager.run_workflow( | |||
| workflow=workflow, | |||
| user_id=application_generate_entity.user_id, | |||
| user_from=UserFrom.ACCOUNT | |||
| if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| else UserFrom.END_USER, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| callbacks=workflow_callbacks, | |||
| call_depth=application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| inputs = self.application_generate_entity.inputs | |||
| files = self.application_generate_entity.files | |||
| def single_iteration_run( | |||
| self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str | |||
| ) -> None: | |||
| """ | |||
| Single iteration run | |||
| """ | |||
| app_record = db.session.query(App).filter(App.id == app_id).first() | |||
| if not app_record: | |||
| raise ValueError('App not found') | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| } | |||
| if not app_record.workflow_id: | |||
| raise ValueError('Workflow not initialized') | |||
| workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) | |||
| if not workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=[], | |||
| ) | |||
| workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] | |||
| # init graph | |||
| graph = self._init_graph(graph_config=workflow.graph_dict) | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| workflow_engine_manager.single_step_run_iteration_workflow_node( | |||
| workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| thread_pool_id=self.workflow_thread_pool_id | |||
| ) | |||
| def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| generator = workflow_entry.run( | |||
| callbacks=workflow_callbacks | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| for event in generator: | |||
| self._handle_event(workflow_entry, event) | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| @@ -15,10 +16,12 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueMessageReplaceEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueuePingEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| @@ -32,19 +35,16 @@ from core.app.entities.task_entities import ( | |||
| MessageAudioStreamResponse, | |||
| StreamResponse, | |||
| TextChunkStreamResponse, | |||
| TextReplaceStreamResponse, | |||
| WorkflowAppBlockingResponse, | |||
| WorkflowAppStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStreamGenerateNodes, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.end.end_node import EndNode | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser | |||
| @@ -52,8 +52,8 @@ from models.workflow import ( | |||
| Workflow, | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowNodeExecution, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| @@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: WorkflowAppGenerateEntity | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _iteration_nested_relations: dict[str, list[str]] | |||
| def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, | |||
| workflow: Workflow, | |||
| @@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| SystemVariableKey.USER_ID: user_id | |||
| } | |||
| self._task_state = WorkflowTaskState( | |||
| iteration_nested_node_ids=[] | |||
| ) | |||
| self._stream_generate_nodes = self._get_stream_generate_nodes() | |||
| self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) | |||
| self._task_state = WorkflowTaskState() | |||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| """ | |||
| @@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if isinstance(stream_response, ErrorStreamResponse): | |||
| raise stream_response.err | |||
| elif isinstance(stream_response, WorkflowFinishStreamResponse): | |||
| workflow_run = db.session.query(WorkflowRun).filter( | |||
| WorkflowRun.id == self._task_state.workflow_run_id).first() | |||
| response = WorkflowAppBlockingResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| workflow_run_id=stream_response.data.id, | |||
| data=WorkflowAppBlockingResponse.Data( | |||
| id=workflow_run.id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| status=workflow_run.status, | |||
| outputs=workflow_run.outputs_dict, | |||
| error=workflow_run.error, | |||
| elapsed_time=workflow_run.elapsed_time, | |||
| total_tokens=workflow_run.total_tokens, | |||
| total_steps=workflow_run.total_steps, | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| finished_at=int(workflow_run.finished_at.timestamp()) | |||
| id=stream_response.data.id, | |||
| workflow_id=stream_response.data.workflow_id, | |||
| status=stream_response.data.status, | |||
| outputs=stream_response.data.outputs, | |||
| error=stream_response.data.error, | |||
| elapsed_time=stream_response.data.elapsed_time, | |||
| total_tokens=stream_response.data.total_tokens, | |||
| total_steps=stream_response.data.total_steps, | |||
| created_at=int(stream_response.data.created_at), | |||
| finished_at=int(stream_response.data.finished_at) | |||
| ) | |||
| ) | |||
| @@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| To stream response. | |||
| :return: | |||
| """ | |||
| workflow_run_id = None | |||
| for stream_response in generator: | |||
| if isinstance(stream_response, WorkflowStartStreamResponse): | |||
| workflow_run_id = stream_response.workflow_run_id | |||
| yield WorkflowAppStreamResponse( | |||
| workflow_run_id=self._task_state.workflow_run_id, | |||
| workflow_run_id=workflow_run_id, | |||
| stream_response=stream_response | |||
| ) | |||
| @@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ | |||
| Generator[StreamResponse, None, None]: | |||
| publisher = None | |||
| tts_publisher = None | |||
| task_id = self._application_generate_entity.task_id | |||
| tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| features_dict = self._workflow.features_dict | |||
| if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ | |||
| 'text_to_speech'].get('autoPlay') == 'enabled': | |||
| publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) | |||
| for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | |||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listenAudioMsg(publisher, task_id=task_id) | |||
| audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) | |||
| if audio_response: | |||
| yield audio_response | |||
| else: | |||
| @@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| start_listener_time = time.time() | |||
| while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: | |||
| try: | |||
| if not publisher: | |||
| if not tts_publisher: | |||
| break | |||
| audio_trunk = publisher.checkAndGetAudio() | |||
| audio_trunk = tts_publisher.checkAndGetAudio() | |||
| if audio_trunk is None: | |||
| # release cpu | |||
| # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) | |||
| @@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| def _process_stream_response( | |||
| self, | |||
| publisher: AppGeneratorTTSPublisher, | |||
| tts_publisher: Optional[AppGeneratorTTSPublisher] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """ | |||
| Process stream response. | |||
| :return: | |||
| """ | |||
| for message in self._queue_manager.listen(): | |||
| if publisher: | |||
| publisher.publish(message=message) | |||
| event = message.event | |||
| graph_runtime_state = None | |||
| workflow_run = None | |||
| if isinstance(event, QueueErrorEvent): | |||
| for queue_message in self._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event) | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| workflow_run = self._handle_workflow_start() | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start() | |||
| yield self._workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueNodeStartedEvent): | |||
| workflow_node_execution = self._handle_node_start(event) | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| # search stream_generate_routes if node id is answer start at node | |||
| if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: | |||
| self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] | |||
| workflow_node_execution = self._handle_node_execution_start( | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| response = self._workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| # generate stream outputs when node started | |||
| yield from self._generate_stream_outputs_when_node_started() | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(event) | |||
| yield self._workflow_node_start_to_stream_response( | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): | |||
| workflow_node_execution = self._handle_node_finished(event) | |||
| yield self._workflow_node_finish_to_stream_response( | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution | |||
| ) | |||
| if isinstance(event, QueueNodeFailedEvent): | |||
| yield from self._handle_iteration_exception( | |||
| task_id=self._application_generate_entity.task_id, | |||
| error=f'Child node failed: {event.error}' | |||
| ) | |||
| elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): | |||
| if isinstance(event, QueueIterationNextEvent): | |||
| # clear ran node execution infos of current iteration | |||
| iteration_relations = self._iteration_nested_relations.get(event.node_id) | |||
| if iteration_relations: | |||
| for node_id in iteration_relations: | |||
| self._task_state.ran_node_execution_infos.pop(node_id, None) | |||
| yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) | |||
| self._handle_iteration_operation(event) | |||
| elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): | |||
| workflow_run = self._handle_workflow_finished( | |||
| event, trace_manager=trace_manager | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationStartEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_iteration_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_iteration_next_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| yield self._workflow_iteration_completed_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event | |||
| ) | |||
| elif isinstance(event, QueueWorkflowSucceededEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if not graph_runtime_state: | |||
| raise Exception('Graph runtime state not initialized.') | |||
| workflow_run = self._handle_workflow_run_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): | |||
| if not workflow_run: | |||
| raise Exception('Workflow run not initialized.') | |||
| if not graph_runtime_state: | |||
| raise Exception('Graph runtime state not initialized.') | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, | |||
| error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| @@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if delta_text is None: | |||
| continue | |||
| if not self._is_stream_out_support( | |||
| event=event | |||
| ): | |||
| continue | |||
| # only publish tts message at text chunk streaming | |||
| if tts_publisher: | |||
| tts_publisher.publish(message=queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._text_chunk_to_stream_response(delta_text) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| yield self._text_replace_to_stream_response(event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| else: | |||
| continue | |||
| if publisher: | |||
| publisher.publish(None) | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: | |||
| @@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| # not save log for debugging | |||
| return | |||
| workflow_app_log = WorkflowAppLog( | |||
| tenant_id=workflow_run.tenant_id, | |||
| app_id=workflow_run.app_id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| workflow_run_id=workflow_run.id, | |||
| created_from=created_from.value, | |||
| created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), | |||
| created_by=self._user.id, | |||
| ) | |||
| workflow_app_log = WorkflowAppLog() | |||
| workflow_app_log.tenant_id = workflow_run.tenant_id | |||
| workflow_app_log.app_id = workflow_run.app_id | |||
| workflow_app_log.workflow_id = workflow_run.workflow_id | |||
| workflow_app_log.workflow_run_id = workflow_run.id | |||
| workflow_app_log.created_from = created_from.value | |||
| workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' | |||
| workflow_app_log.created_by = self._user.id | |||
| db.session.add(workflow_app_log) | |||
| db.session.commit() | |||
| db.session.close() | |||
| @@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| ) | |||
| return response | |||
| def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: | |||
| """ | |||
| Text replace to stream response. | |||
| :param text: text | |||
| :return: | |||
| """ | |||
| return TextReplaceStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| text=TextReplaceStreamResponse.Data(text=text) | |||
| ) | |||
| def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: | |||
| """ | |||
| Get stream generate nodes. | |||
| :return: | |||
| """ | |||
| # find all answer nodes | |||
| graph = self._workflow.graph_dict | |||
| end_node_configs = [ | |||
| node for node in graph['nodes'] | |||
| if node.get('data', {}).get('type') == NodeType.END.value | |||
| ] | |||
| # parse stream output node value selectors of end nodes | |||
| stream_generate_routes = {} | |||
| for node_config in end_node_configs: | |||
| # get generate route for stream output | |||
| end_node_id = node_config['id'] | |||
| generate_nodes = EndNode.extract_generate_nodes(graph, node_config) | |||
| start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) | |||
| if not start_node_ids: | |||
| continue | |||
| for start_node_id in start_node_ids: | |||
| stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( | |||
| end_node_id=end_node_id, | |||
| stream_node_ids=generate_nodes | |||
| ) | |||
| return stream_generate_routes | |||
| def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ | |||
| -> list[str]: | |||
| """ | |||
| Get end start at node id. | |||
| :param graph: graph | |||
| :param target_node_id: target node ID | |||
| :return: | |||
| """ | |||
| nodes = graph.get('nodes') | |||
| edges = graph.get('edges') | |||
| # fetch all ingoing edges from source node | |||
| ingoing_edges = [] | |||
| for edge in edges: | |||
| if edge.get('target') == target_node_id: | |||
| ingoing_edges.append(edge) | |||
| if not ingoing_edges: | |||
| return [] | |||
| start_node_ids = [] | |||
| for ingoing_edge in ingoing_edges: | |||
| source_node_id = ingoing_edge.get('source') | |||
| source_node = next((node for node in nodes if node.get('id') == source_node_id), None) | |||
| if not source_node: | |||
| continue | |||
| node_type = source_node.get('data', {}).get('type') | |||
| node_iteration_id = source_node.get('data', {}).get('iteration_id') | |||
| iteration_start_node_id = None | |||
| if node_iteration_id: | |||
| iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) | |||
| iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') | |||
| if node_type in [ | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER.value | |||
| ]: | |||
| start_node_id = target_node_id | |||
| start_node_ids.append(start_node_id) | |||
| elif node_type == NodeType.START.value or \ | |||
| node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): | |||
| start_node_id = source_node_id | |||
| start_node_ids.append(start_node_id) | |||
| else: | |||
| sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) | |||
| if sub_start_node_ids: | |||
| start_node_ids.extend(sub_start_node_ids) | |||
| return start_node_ids | |||
| def _generate_stream_outputs_when_node_started(self) -> Generator: | |||
| """ | |||
| Generate stream outputs. | |||
| :return: | |||
| """ | |||
| if self._task_state.current_stream_generate_state: | |||
| stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids | |||
| for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): | |||
| if node_id not in stream_node_ids: | |||
| continue | |||
| node_execution_info = self._task_state.ran_node_execution_infos[node_id] | |||
| # get chunk node execution | |||
| route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() | |||
| if not route_chunk_node_execution: | |||
| continue | |||
| outputs = route_chunk_node_execution.outputs_dict | |||
| if not outputs: | |||
| continue | |||
| # get value from outputs | |||
| text = outputs.get('text') | |||
| if text: | |||
| self._task_state.answer += text | |||
| yield self._text_chunk_to_stream_response(text) | |||
| db.session.close() | |||
| def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: | |||
| """ | |||
| Is stream out support | |||
| :param event: queue text chunk event | |||
| :return: | |||
| """ | |||
| if not event.metadata: | |||
| return False | |||
| if 'node_id' not in event.metadata: | |||
| return False | |||
| node_id = event.metadata.get('node_id') | |||
| node_type = event.metadata.get('node_type') | |||
| stream_output_value_selector = event.metadata.get('value_selector') | |||
| if not stream_output_value_selector: | |||
| return False | |||
| if not self._task_state.current_stream_generate_state: | |||
| return False | |||
| if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: | |||
| return False | |||
| if node_type != NodeType.LLM: | |||
| # only LLM support chunk stream output | |||
| return False | |||
| return True | |||
| def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: | |||
| """ | |||
| Get iteration nested relations. | |||
| :param graph: graph | |||
| :return: | |||
| """ | |||
| nodes = graph.get('nodes') | |||
| iteration_ids = [node.get('id') for node in nodes | |||
| if node.get('data', {}).get('type') in [ | |||
| NodeType.ITERATION.value, | |||
| NodeType.LOOP.value, | |||
| ]] | |||
| return { | |||
| iteration_id: [ | |||
| node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id | |||
| ] for iteration_id in iteration_ids | |||
| } | |||
| @@ -1,200 +0,0 @@ | |||
| from typing import Any, Optional | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from models.workflow import Workflow | |||
| class WorkflowEventTriggerCallback(WorkflowCallback): | |||
| def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): | |||
| self._queue_manager = queue_manager | |||
| def on_workflow_run_started(self) -> None: | |||
| """ | |||
| Workflow run started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowStartedEvent(), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_run_succeeded(self) -> None: | |||
| """ | |||
| Workflow run succeeded | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowSucceededEvent(), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_run_failed(self, error: str) -> None: | |||
| """ | |||
| Workflow run failed | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueWorkflowFailedEvent( | |||
| error=error | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_started(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| node_run_index: int = 1, | |||
| predecessor_node_id: Optional[str] = None) -> None: | |||
| """ | |||
| Workflow node execute started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeStartedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| node_run_index=node_run_index, | |||
| predecessor_node_id=predecessor_node_id | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_succeeded(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute succeeded | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeSucceededEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| execution_metadata=execution_metadata | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_node_execute_failed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| error: str, | |||
| inputs: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute failed | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueNodeFailedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| outputs=outputs, | |||
| process_data=process_data, | |||
| error=error | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish text chunk | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueTextChunkEvent( | |||
| text=text, | |||
| metadata={ | |||
| "node_id": node_id, | |||
| **metadata | |||
| } | |||
| ), PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_started(self, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int = 1, | |||
| node_data: Optional[BaseNodeData] = None, | |||
| inputs: dict = None, | |||
| predecessor_node_id: Optional[str] = None, | |||
| metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish iteration started | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueIterationStartEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_run_index=node_run_index, | |||
| node_data=node_data, | |||
| inputs=inputs, | |||
| predecessor_node_id=predecessor_node_id, | |||
| metadata=metadata | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_next(self, node_id: str, | |||
| node_type: NodeType, | |||
| index: int, | |||
| node_run_index: int, | |||
| output: Optional[Any]) -> None: | |||
| """ | |||
| Publish iteration next | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueIterationNextEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| index=index, | |||
| node_run_index=node_run_index, | |||
| output=output | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_workflow_iteration_completed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int, | |||
| outputs: dict) -> None: | |||
| """ | |||
| Publish iteration completed | |||
| """ | |||
| self._queue_manager.publish( | |||
| QueueIterationCompletedEvent( | |||
| node_id=node_id, | |||
| node_type=node_type, | |||
| node_run_index=node_run_index, | |||
| outputs=outputs | |||
| ), | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| def on_event(self, event: AppQueueEvent) -> None: | |||
| """ | |||
| Publish event | |||
| """ | |||
| pass | |||
| @@ -0,0 +1,379 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueueRetrieverResourcesEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| IterationRunFailedEvent, | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ParallelBranchRunFailedEvent, | |||
| ParallelBranchRunStartedEvent, | |||
| ParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| class WorkflowBasedAppRunner(AppRunner): | |||
| def __init__(self, queue_manager: AppQueueManager): | |||
| self.queue_manager = queue_manager | |||
| def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: | |||
| """ | |||
| Init graph | |||
| """ | |||
| if 'nodes' not in graph_config or 'edges' not in graph_config: | |||
| raise ValueError('nodes or edges not found in workflow graph') | |||
| if not isinstance(graph_config.get('nodes'), list): | |||
| raise ValueError('nodes in workflow graph must be a list') | |||
| if not isinstance(graph_config.get('edges'), list): | |||
| raise ValueError('edges in workflow graph must be a list') | |||
| # init graph | |||
| graph = Graph.init( | |||
| graph_config=graph_config | |||
| ) | |||
| if not graph: | |||
| raise ValueError('graph not found in workflow') | |||
| return graph | |||
| def _get_graph_and_variable_pool_of_single_iteration( | |||
| self, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_inputs: dict, | |||
| ) -> tuple[Graph, VariablePool]: | |||
| """ | |||
| Get variable pool of single iteration | |||
| """ | |||
| # fetch workflow graph | |||
| graph_config = workflow.graph_dict | |||
| if not graph_config: | |||
| raise ValueError('workflow graph not found') | |||
| graph_config = cast(dict[str, Any], graph_config) | |||
| if 'nodes' not in graph_config or 'edges' not in graph_config: | |||
| raise ValueError('nodes or edges not found in workflow graph') | |||
| if not isinstance(graph_config.get('nodes'), list): | |||
| raise ValueError('nodes in workflow graph must be a list') | |||
| if not isinstance(graph_config.get('edges'), list): | |||
| raise ValueError('edges in workflow graph must be a list') | |||
| # filter nodes only in iteration | |||
| node_configs = [ | |||
| node for node in graph_config.get('nodes', []) | |||
| if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id | |||
| ] | |||
| graph_config['nodes'] = node_configs | |||
| node_ids = [node.get('id') for node in node_configs] | |||
| # filter edges only in iteration | |||
| edge_configs = [ | |||
| edge for edge in graph_config.get('edges', []) | |||
| if (edge.get('source') is None or edge.get('source') in node_ids) | |||
| and (edge.get('target') is None or edge.get('target') in node_ids) | |||
| ] | |||
| graph_config['edges'] = edge_configs | |||
| # init graph | |||
| graph = Graph.init( | |||
| graph_config=graph_config, | |||
| root_node_id=node_id | |||
| ) | |||
| if not graph: | |||
| raise ValueError('graph not found in workflow') | |||
| # fetch node config from node id | |||
| iteration_node_config = None | |||
| for node in node_configs: | |||
| if node.get('id') == node_id: | |||
| iteration_node_config = node | |||
| break | |||
| if not iteration_node_config: | |||
| raise ValueError('iteration node id not found in workflow graph') | |||
| # Get node class | |||
| node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) | |||
| node_cls = node_classes.get(node_type) | |||
| node_cls = cast(type[BaseNode], node_cls) | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| user_inputs={}, | |||
| environment_variables=workflow.environment_variables, | |||
| ) | |||
| try: | |||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||
| graph_config=workflow.graph_dict, | |||
| config=iteration_node_config | |||
| ) | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| WorkflowEntry.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=workflow.tenant_id, | |||
| node_type=node_type, | |||
| node_data=IterationNodeData(**iteration_node_config.get('data', {})) | |||
| ) | |||
| return graph, variable_pool | |||
| def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: | |||
| """ | |||
| Handle event | |||
| :param workflow_entry: workflow entry | |||
| :param event: event | |||
| """ | |||
| if isinstance(event, GraphRunStartedEvent): | |||
| self._publish_event( | |||
| QueueWorkflowStartedEvent( | |||
| graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state | |||
| ) | |||
| ) | |||
| elif isinstance(event, GraphRunSucceededEvent): | |||
| self._publish_event( | |||
| QueueWorkflowSucceededEvent(outputs=event.outputs) | |||
| ) | |||
| elif isinstance(event, GraphRunFailedEvent): | |||
| self._publish_event( | |||
| QueueWorkflowFailedEvent(error=event.error) | |||
| ) | |||
| elif isinstance(event, NodeRunStartedEvent): | |||
| self._publish_event( | |||
| QueueNodeStartedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| node_run_index=event.route_node_state.index, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| self._publish_event( | |||
| QueueNodeSucceededEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result else {}, | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result else {}, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeFailedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result else {}, | |||
| error=event.route_node_state.node_run_result.error | |||
| if event.route_node_state.node_run_result | |||
| and event.route_node_state.node_run_result.error | |||
| else "Unknown error", | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| self._publish_event( | |||
| QueueTextChunkEvent( | |||
| text=event.chunk_content, | |||
| from_variable_selector=event.from_variable_selector, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunRetrieverResourceEvent): | |||
| self._publish_event( | |||
| QueueRetrieverResourcesEvent( | |||
| retriever_resources=event.retriever_resources, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunStartedEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunStartedEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunSucceededEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunSucceededEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunFailedEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunFailedEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| error=event.error | |||
| ) | |||
| ) | |||
| elif isinstance(event, IterationRunStartedEvent): | |||
| self._publish_event( | |||
| QueueIterationStartEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| metadata=event.metadata | |||
| ) | |||
| ) | |||
| elif isinstance(event, IterationRunNextEvent): | |||
| self._publish_event( | |||
| QueueIterationNextEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| index=event.index, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| output=event.pre_iteration_output, | |||
| ) | |||
| ) | |||
| elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): | |||
| self._publish_event( | |||
| QueueIterationCompletedEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| outputs=event.outputs, | |||
| metadata=event.metadata, | |||
| steps=event.steps, | |||
| error=event.error if isinstance(event, IterationRunFailedEvent) else None | |||
| ) | |||
| ) | |||
| def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||
| self.queue_manager.publish( | |||
| event, | |||
| PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| @@ -1,10 +1,24 @@ | |||
| from typing import Optional | |||
| from core.app.entities.queue_entities import AppQueueEvent | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| IterationRunFailedEvent, | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ParallelBranchRunFailedEvent, | |||
| ParallelBranchRunStartedEvent, | |||
| ParallelBranchRunSucceededEvent, | |||
| ) | |||
| _TEXT_COLOR_MAPPING = { | |||
| "blue": "36;1", | |||
| @@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback): | |||
| def __init__(self) -> None: | |||
| self.current_node_id = None | |||
| def on_workflow_run_started(self) -> None: | |||
| """ | |||
| Workflow run started | |||
| """ | |||
| self.print_text("\n[on_workflow_run_started]", color='pink') | |||
| def on_workflow_run_succeeded(self) -> None: | |||
| def on_event( | |||
| self, | |||
| event: GraphEngineEvent | |||
| ) -> None: | |||
| if isinstance(event, GraphRunStartedEvent): | |||
| self.print_text("\n[GraphRunStartedEvent]", color='pink') | |||
| elif isinstance(event, GraphRunSucceededEvent): | |||
| self.print_text("\n[GraphRunSucceededEvent]", color='green') | |||
| elif isinstance(event, GraphRunFailedEvent): | |||
| self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') | |||
| elif isinstance(event, NodeRunStartedEvent): | |||
| self.on_workflow_node_execute_started( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| self.on_workflow_node_execute_succeeded( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, NodeRunFailedEvent): | |||
| self.on_workflow_node_execute_failed( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| self.on_node_text_chunk( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunStartedEvent): | |||
| self.on_workflow_parallel_started( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): | |||
| self.on_workflow_parallel_completed( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, IterationRunStartedEvent): | |||
| self.on_workflow_iteration_started( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, IterationRunNextEvent): | |||
| self.on_workflow_iteration_next( | |||
| event=event | |||
| ) | |||
| elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): | |||
| self.on_workflow_iteration_completed( | |||
| event=event | |||
| ) | |||
| else: | |||
| self.print_text(f"\n[{event.__class__.__name__}]", color='blue') | |||
| def on_workflow_node_execute_started( | |||
| self, | |||
| event: NodeRunStartedEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow run succeeded | |||
| Workflow node execute started | |||
| """ | |||
| self.print_text("\n[on_workflow_run_succeeded]", color='green') | |||
| self.print_text("\n[NodeRunStartedEvent]", color='yellow') | |||
| self.print_text(f"Node ID: {event.node_id}", color='yellow') | |||
| self.print_text(f"Node Title: {event.node_data.title}", color='yellow') | |||
| self.print_text(f"Type: {event.node_type.value}", color='yellow') | |||
| def on_workflow_run_failed(self, error: str) -> None: | |||
| def on_workflow_node_execute_succeeded( | |||
| self, | |||
| event: NodeRunSucceededEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow run failed | |||
| Workflow node execute succeeded | |||
| """ | |||
| self.print_text("\n[on_workflow_run_failed]", color='red') | |||
| def on_workflow_node_execute_started(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| node_run_index: int = 1, | |||
| predecessor_node_id: Optional[str] = None) -> None: | |||
| route_node_state = event.route_node_state | |||
| self.print_text("\n[NodeRunSucceededEvent]", color='green') | |||
| self.print_text(f"Node ID: {event.node_id}", color='green') | |||
| self.print_text(f"Node Title: {event.node_data.title}", color='green') | |||
| self.print_text(f"Type: {event.node_type.value}", color='green') | |||
| if route_node_state.node_run_result: | |||
| node_run_result = route_node_state.node_run_result | |||
| self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", | |||
| color='green') | |||
| self.print_text( | |||
| f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", | |||
| color='green') | |||
| self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", | |||
| color='green') | |||
| self.print_text( | |||
| f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", | |||
| color='green') | |||
| def on_workflow_node_execute_failed( | |||
| self, | |||
| event: NodeRunFailedEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow node execute started | |||
| Workflow node execute failed | |||
| """ | |||
| self.print_text("\n[on_workflow_node_execute_started]", color='yellow') | |||
| self.print_text(f"Node ID: {node_id}", color='yellow') | |||
| self.print_text(f"Type: {node_type.value}", color='yellow') | |||
| self.print_text(f"Index: {node_run_index}", color='yellow') | |||
| if predecessor_node_id: | |||
| self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') | |||
| def on_workflow_node_execute_succeeded(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None) -> None: | |||
| route_node_state = event.route_node_state | |||
| self.print_text("\n[NodeRunFailedEvent]", color='red') | |||
| self.print_text(f"Node ID: {event.node_id}", color='red') | |||
| self.print_text(f"Node Title: {event.node_data.title}", color='red') | |||
| self.print_text(f"Type: {event.node_type.value}", color='red') | |||
| if route_node_state.node_run_result: | |||
| node_run_result = route_node_state.node_run_result | |||
| self.print_text(f"Error: {node_run_result.error}", color='red') | |||
| self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", | |||
| color='red') | |||
| self.print_text( | |||
| f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", | |||
| color='red') | |||
| self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", | |||
| color='red') | |||
| def on_node_text_chunk( | |||
| self, | |||
| event: NodeRunStreamChunkEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow node execute succeeded | |||
| Publish text chunk | |||
| """ | |||
| self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') | |||
| self.print_text(f"Node ID: {node_id}", color='green') | |||
| self.print_text(f"Type: {node_type.value}", color='green') | |||
| self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') | |||
| self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') | |||
| self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') | |||
| self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", | |||
| color='green') | |||
| def on_workflow_node_execute_failed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| error: str, | |||
| inputs: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None) -> None: | |||
| route_node_state = event.route_node_state | |||
| if not self.current_node_id or self.current_node_id != route_node_state.node_id: | |||
| self.current_node_id = route_node_state.node_id | |||
| self.print_text('\n[NodeRunStreamChunkEvent]') | |||
| self.print_text(f"Node ID: {route_node_state.node_id}") | |||
| node_run_result = route_node_state.node_run_result | |||
| if node_run_result: | |||
| self.print_text( | |||
| f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") | |||
| self.print_text(event.chunk_content, color="pink", end="") | |||
| def on_workflow_parallel_started( | |||
| self, | |||
| event: ParallelBranchRunStartedEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow node execute failed | |||
| Publish parallel started | |||
| """ | |||
| self.print_text("\n[on_workflow_node_execute_failed]", color='red') | |||
| self.print_text(f"Node ID: {node_id}", color='red') | |||
| self.print_text(f"Type: {node_type.value}", color='red') | |||
| self.print_text(f"Error: {error}", color='red') | |||
| self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') | |||
| self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') | |||
| self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') | |||
| self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') | |||
| self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') | |||
| self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') | |||
| if event.in_iteration_id: | |||
| self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') | |||
| def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: | |||
| def on_workflow_parallel_completed( | |||
| self, | |||
| event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent | |||
| ) -> None: | |||
| """ | |||
| Publish text chunk | |||
| Publish parallel completed | |||
| """ | |||
| if not self.current_node_id or self.current_node_id != node_id: | |||
| self.current_node_id = node_id | |||
| self.print_text('\n[on_node_text_chunk]') | |||
| self.print_text(f"Node ID: {node_id}") | |||
| self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") | |||
| if isinstance(event, ParallelBranchRunSucceededEvent): | |||
| color = 'blue' | |||
| elif isinstance(event, ParallelBranchRunFailedEvent): | |||
| color = 'red' | |||
| self.print_text(text, color="pink", end="") | |||
| self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) | |||
| self.print_text(f"Parallel ID: {event.parallel_id}", color=color) | |||
| self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) | |||
| if event.in_iteration_id: | |||
| self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) | |||
| def on_workflow_iteration_started(self, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int = 1, | |||
| node_data: Optional[BaseNodeData] = None, | |||
| inputs: dict = None, | |||
| predecessor_node_id: Optional[str] = None, | |||
| metadata: Optional[dict] = None) -> None: | |||
| if isinstance(event, ParallelBranchRunFailedEvent): | |||
| self.print_text(f"Error: {event.error}", color=color) | |||
| def on_workflow_iteration_started( | |||
| self, | |||
| event: IterationRunStartedEvent | |||
| ) -> None: | |||
| """ | |||
| Publish iteration started | |||
| """ | |||
| self.print_text("\n[on_workflow_iteration_started]", color='blue') | |||
| self.print_text(f"Node ID: {node_id}", color='blue') | |||
| self.print_text("\n[IterationRunStartedEvent]", color='blue') | |||
| self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') | |||
| def on_workflow_iteration_next(self, node_id: str, | |||
| node_type: NodeType, | |||
| index: int, | |||
| node_run_index: int, | |||
| output: Optional[dict]) -> None: | |||
| def on_workflow_iteration_next( | |||
| self, | |||
| event: IterationRunNextEvent | |||
| ) -> None: | |||
| """ | |||
| Publish iteration next | |||
| """ | |||
| self.print_text("\n[on_workflow_iteration_next]", color='blue') | |||
| self.print_text("\n[IterationRunNextEvent]", color='blue') | |||
| self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') | |||
| self.print_text(f"Iteration Index: {event.index}", color='blue') | |||
| def on_workflow_iteration_completed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int, | |||
| outputs: dict) -> None: | |||
| def on_workflow_iteration_completed( | |||
| self, | |||
| event: IterationRunSucceededEvent | IterationRunFailedEvent | |||
| ) -> None: | |||
| """ | |||
| Publish iteration completed | |||
| """ | |||
| self.print_text("\n[on_workflow_iteration_completed]", color='blue') | |||
| def on_event(self, event: AppQueueEvent) -> None: | |||
| """ | |||
| Publish event | |||
| """ | |||
| self.print_text("\n[on_workflow_event]", color='blue') | |||
| self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') | |||
| self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') | |||
| self.print_text(f"Node ID: {event.iteration_id}", color='blue') | |||
| def print_text( | |||
| self, text: str, color: Optional[str] = None, end: str = "\n" | |||
| @@ -1,3 +1,4 @@ | |||
| from datetime import datetime | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| @@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| class QueueEvent(str, Enum): | |||
| @@ -31,6 +33,9 @@ class QueueEvent(str, Enum): | |||
| ANNOTATION_REPLY = "annotation_reply" | |||
| AGENT_THOUGHT = "agent_thought" | |||
| MESSAGE_FILE = "message_file" | |||
| PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" | |||
| PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" | |||
| PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" | |||
| ERROR = "error" | |||
| PING = "ping" | |||
| STOP = "stop" | |||
| @@ -38,7 +43,7 @@ class QueueEvent(str, Enum): | |||
| class AppQueueEvent(BaseModel): | |||
| """ | |||
| QueueEvent entity | |||
| QueueEvent abstract entity | |||
| """ | |||
| event: QueueEvent | |||
| @@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel): | |||
| class QueueLLMChunkEvent(AppQueueEvent): | |||
| """ | |||
| QueueLLMChunkEvent entity | |||
| Only for basic mode apps | |||
| """ | |||
| event: QueueEvent = QueueEvent.LLM_CHUNK | |||
| chunk: LLMResultChunk | |||
| @@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent): | |||
| QueueIterationStartEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.ITERATION_START | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| start_at: datetime | |||
| node_run_index: int | |||
| inputs: dict = None | |||
| inputs: Optional[dict[str, Any]] = None | |||
| predecessor_node_id: Optional[str] = None | |||
| metadata: Optional[dict] = None | |||
| metadata: Optional[dict[str, Any]] = None | |||
| class QueueIterationNextEvent(AppQueueEvent): | |||
| """ | |||
| @@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent): | |||
| event: QueueEvent = QueueEvent.ITERATION_NEXT | |||
| index: int | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| node_run_index: int | |||
| output: Optional[Any] = None # output for the current iteration | |||
| @@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent): | |||
| """ | |||
| QueueIterationCompletedEvent entity | |||
| """ | |||
| event:QueueEvent = QueueEvent.ITERATION_COMPLETED | |||
| event: QueueEvent = QueueEvent.ITERATION_COMPLETED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| start_at: datetime | |||
| node_run_index: int | |||
| outputs: dict | |||
| inputs: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| metadata: Optional[dict[str, Any]] = None | |||
| steps: int = 0 | |||
| error: Optional[str] = None | |||
| class QueueTextChunkEvent(AppQueueEvent): | |||
| """ | |||
| @@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.TEXT_CHUNK | |||
| text: str | |||
| metadata: Optional[dict] = None | |||
| from_variable_selector: Optional[list[str]] = None | |||
| """from variable selector""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class QueueAgentMessageEvent(AppQueueEvent): | |||
| @@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES | |||
| retriever_resources: list[dict] | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class QueueAnnotationReplyEvent(AppQueueEvent): | |||
| @@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): | |||
| QueueWorkflowStartedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.WORKFLOW_STARTED | |||
| graph_runtime_state: GraphRuntimeState | |||
| class QueueWorkflowSucceededEvent(AppQueueEvent): | |||
| @@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): | |||
| QueueWorkflowSucceededEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED | |||
| outputs: Optional[dict[str, Any]] = None | |||
| class QueueWorkflowFailedEvent(AppQueueEvent): | |||
| @@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_STARTED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| node_run_index: int = 1 | |||
| predecessor_node_id: Optional[str] = None | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| class QueueNodeSucceededEvent(AppQueueEvent): | |||
| @@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_SUCCEEDED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| inputs: Optional[dict] = None | |||
| process_data: Optional[dict] = None | |||
| outputs: Optional[dict] = None | |||
| execution_metadata: Optional[dict] = None | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None | |||
| error: Optional[str] = None | |||
| @@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_FAILED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| inputs: Optional[dict] = None | |||
| outputs: Optional[dict] = None | |||
| process_data: Optional[dict] = None | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| error: str | |||
| @@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent): | |||
| event: QueueEvent = QueueEvent.STOP | |||
| stopped_by: StopBy | |||
| def get_stop_reason(self) -> str: | |||
| """ | |||
| To stop reason | |||
| """ | |||
| reason_mapping = { | |||
| QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', | |||
| QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', | |||
| QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', | |||
| QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' | |||
| } | |||
| return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') | |||
| class QueueMessage(BaseModel): | |||
| """ | |||
| QueueMessage entity | |||
| QueueMessage abstract entity | |||
| """ | |||
| task_id: str | |||
| app_mode: str | |||
| @@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage): | |||
| WorkflowQueueMessage entity | |||
| """ | |||
| pass | |||
| class QueueParallelBranchRunStartedEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunStartedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class QueueParallelBranchRunSucceededEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunSucceededEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class QueueParallelBranchRunFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| error: str | |||
| @@ -3,40 +3,11 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.answer.entities import GenerateRouteChunk | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class WorkflowStreamGenerateNodes(BaseModel): | |||
| """ | |||
| WorkflowStreamGenerateNodes entity | |||
| """ | |||
| end_node_id: str | |||
| stream_node_ids: list[str] | |||
| class ChatflowStreamGenerateRoute(BaseModel): | |||
| """ | |||
| ChatflowStreamGenerateRoute entity | |||
| """ | |||
| answer_node_id: str | |||
| generate_route: list[GenerateRouteChunk] | |||
| current_route_position: int = 0 | |||
| class NodeExecutionInfo(BaseModel): | |||
| """ | |||
| NodeExecutionInfo entity | |||
| """ | |||
| workflow_node_execution_id: str | |||
| node_type: NodeType | |||
| start_at: float | |||
| class TaskState(BaseModel): | |||
| """ | |||
| TaskState entity | |||
| @@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState): | |||
| """ | |||
| answer: str = "" | |||
| workflow_run_id: Optional[str] = None | |||
| start_at: Optional[float] = None | |||
| total_tokens: int = 0 | |||
| total_steps: int = 0 | |||
| ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} | |||
| latest_node_execution_info: Optional[NodeExecutionInfo] = None | |||
| current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None | |||
| iteration_nested_node_ids: list[str] = None | |||
| class AdvancedChatTaskState(WorkflowTaskState): | |||
| """ | |||
| AdvancedChatTaskState entity | |||
| """ | |||
| usage: LLMUsage | |||
| current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None | |||
| class StreamEvent(Enum): | |||
| """ | |||
| @@ -97,6 +47,8 @@ class StreamEvent(Enum): | |||
| WORKFLOW_FINISHED = "workflow_finished" | |||
| NODE_STARTED = "node_started" | |||
| NODE_FINISHED = "node_finished" | |||
| PARALLEL_BRANCH_STARTED = "parallel_branch_started" | |||
| PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" | |||
| ITERATION_STARTED = "iteration_started" | |||
| ITERATION_NEXT = "iteration_next" | |||
| ITERATION_COMPLETED = "iteration_completed" | |||
| @@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse): | |||
| inputs: Optional[dict] = None | |||
| created_at: int | |||
| extras: dict = {} | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.NODE_STARTED | |||
| workflow_run_id: str | |||
| @@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse): | |||
| "predecessor_node_id": self.data.predecessor_node_id, | |||
| "inputs": None, | |||
| "created_at": self.data.created_at, | |||
| "extras": {} | |||
| "extras": {}, | |||
| "parallel_id": self.data.parallel_id, | |||
| "parallel_start_node_id": self.data.parallel_start_node_id, | |||
| "parent_parallel_id": self.data.parent_parallel_id, | |||
| "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, | |||
| "iteration_id": self.data.iteration_id, | |||
| } | |||
| } | |||
| @@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse): | |||
| created_at: int | |||
| finished_at: int | |||
| files: Optional[list[dict]] = [] | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.NODE_FINISHED | |||
| workflow_run_id: str | |||
| @@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse): | |||
| "execution_metadata": None, | |||
| "created_at": self.data.created_at, | |||
| "finished_at": self.data.finished_at, | |||
| "files": [] | |||
| "files": [], | |||
| "parallel_id": self.data.parallel_id, | |||
| "parallel_start_node_id": self.data.parallel_start_node_id, | |||
| "parent_parallel_id": self.data.parent_parallel_id, | |||
| "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, | |||
| "iteration_id": self.data.iteration_id, | |||
| } | |||
| } | |||
| class ParallelBranchStartStreamResponse(StreamResponse): | |||
| """ | |||
| ParallelBranchStartStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| parallel_id: str | |||
| parallel_branch_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| created_at: int | |||
| event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED | |||
| workflow_run_id: str | |||
| data: Data | |||
| class ParallelBranchFinishedStreamResponse(StreamResponse): | |||
| """ | |||
| ParallelBranchFinishedStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| parallel_id: str | |||
| parallel_branch_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| status: str | |||
| error: Optional[str] = None | |||
| created_at: int | |||
| event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED | |||
| workflow_run_id: str | |||
| data: Data | |||
| class IterationNodeStartStreamResponse(StreamResponse): | |||
| @@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse): | |||
| extras: dict = {} | |||
| metadata: dict = {} | |||
| inputs: dict = {} | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_STARTED | |||
| workflow_run_id: str | |||
| @@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse): | |||
| created_at: int | |||
| pre_iteration_output: Optional[Any] = None | |||
| extras: dict = {} | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_NEXT | |||
| workflow_run_id: str | |||
| @@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): | |||
| title: str | |||
| outputs: Optional[dict] = None | |||
| created_at: int | |||
| extras: dict = None | |||
| inputs: dict = None | |||
| extras: Optional[dict] = None | |||
| inputs: Optional[dict] = None | |||
| status: WorkflowNodeExecutionStatus | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| @@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): | |||
| execution_metadata: Optional[dict] = None | |||
| finished_at: int | |||
| steps: int | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_COMPLETED | |||
| workflow_run_id: str | |||
| @@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): | |||
| """ | |||
| WorkflowAppStreamResponse entity | |||
| """ | |||
| workflow_run_id: str | |||
| workflow_run_id: Optional[str] = None | |||
| class AppBlockingResponse(BaseModel): | |||
| @@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): | |||
| workflow_run_id: str | |||
| data: Data | |||
| class WorkflowIterationState(BaseModel): | |||
| """ | |||
| WorkflowIterationState entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| parent_iteration_id: Optional[str] = None | |||
| iteration_id: str | |||
| current_index: int | |||
| iteration_steps_boundary: list[int] = None | |||
| node_execution_id: str | |||
| started_at: float | |||
| inputs: dict = None | |||
| total_tokens: int = 0 | |||
| node_data: BaseNodeData | |||
| current_iterations: dict[str, Data] = None | |||
| @@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline: | |||
| err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) | |||
| if message: | |||
| message = db.session.query(Message).filter(Message.id == message.id).first() | |||
| err_desc = self._error_to_desc(err) | |||
| message.status = 'error' | |||
| message.error = err_desc | |||
| refetch_message = db.session.query(Message).filter(Message.id == message.id).first() | |||
| db.session.commit() | |||
| if refetch_message: | |||
| err_desc = self._error_to_desc(err) | |||
| refetch_message.status = 'error' | |||
| refetch_message.error = err_desc | |||
| db.session.commit() | |||
| return err | |||
| def _error_to_desc(cls, e: Exception) -> str: | |||
| def _error_to_desc(self, e: Exception) -> str: | |||
| """ | |||
| Error to desc. | |||
| :param e: exception | |||
| @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import ( | |||
| AgentChatAppGenerateEntity, | |||
| ChatAppGenerateEntity, | |||
| CompletionAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAnnotationReplyEvent, | |||
| @@ -16,11 +15,11 @@ from core.app.entities.queue_entities import ( | |||
| QueueRetrieverResourcesEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AdvancedChatTaskState, | |||
| EasyUITaskState, | |||
| MessageFileStreamResponse, | |||
| MessageReplaceStreamResponse, | |||
| MessageStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| @@ -36,7 +35,7 @@ class MessageCycleManage: | |||
| AgentChatAppGenerateEntity, | |||
| AdvancedChatAppGenerateEntity | |||
| ] | |||
| _task_state: Union[EasyUITaskState, AdvancedChatTaskState] | |||
| _task_state: Union[EasyUITaskState, WorkflowTaskState] | |||
| def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: | |||
| """ | |||
| @@ -45,6 +44,9 @@ class MessageCycleManage: | |||
| :param query: query | |||
| :return: thread | |||
| """ | |||
| if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): | |||
| return None | |||
| is_first_message = self._application_generate_entity.conversation_id is None | |||
| extras = self._application_generate_entity.extras | |||
| auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) | |||
| @@ -52,7 +54,7 @@ class MessageCycleManage: | |||
| if auto_generate_conversation_name and is_first_message: | |||
| # start generate thread | |||
| thread = Thread(target=self._generate_conversation_name_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'flask_app': current_app._get_current_object(), # type: ignore | |||
| 'conversation_id': conversation.id, | |||
| 'query': query | |||
| }) | |||
| @@ -75,6 +77,9 @@ class MessageCycleManage: | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| return | |||
| if conversation.mode != AppMode.COMPLETION.value: | |||
| app_model = conversation.app | |||
| if not app_model: | |||
| @@ -121,34 +126,13 @@ class MessageCycleManage: | |||
| if self._application_generate_entity.app_config.additional_features.show_retrieve_source: | |||
| self._task_state.metadata['retriever_resources'] = event.retriever_resources | |||
| def _get_response_metadata(self) -> dict: | |||
| """ | |||
| Get response metadata by invoke from. | |||
| :return: | |||
| """ | |||
| metadata = {} | |||
| # show_retrieve_source | |||
| if 'retriever_resources' in self._task_state.metadata: | |||
| metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] | |||
| # show annotation reply | |||
| if 'annotation_reply' in self._task_state.metadata: | |||
| metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] | |||
| # show usage | |||
| if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: | |||
| metadata['usage'] = self._task_state.metadata['usage'] | |||
| return metadata | |||
| def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||
| """ | |||
| Message file to stream response. | |||
| :param event: event | |||
| :return: | |||
| """ | |||
| message_file: MessageFile = ( | |||
| message_file = ( | |||
| db.session.query(MessageFile) | |||
| .filter(MessageFile.id == event.message_file_id) | |||
| .first() | |||
| @@ -1,33 +1,41 @@ | |||
| import json | |||
| import time | |||
| from datetime import datetime, timezone | |||
| from typing import Optional, Union, cast | |||
| from typing import Any, Optional, Union, cast | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueStopEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| NodeExecutionInfo, | |||
| IterationNodeCompletedStreamResponse, | |||
| IterationNodeNextStreamResponse, | |||
| IterationNodeStartStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeStartStreamResponse, | |||
| ParallelBranchFinishedStreamResponse, | |||
| ParallelBranchStartStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage | |||
| from core.file.file_obj import FileVar | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser | |||
| @@ -41,54 +49,56 @@ from models.workflow import ( | |||
| WorkflowRunStatus, | |||
| WorkflowRunTriggeredFrom, | |||
| ) | |||
| from services.workflow_service import WorkflowService | |||
| class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| def _init_workflow_run(self, workflow: Workflow, | |||
| triggered_from: WorkflowRunTriggeredFrom, | |||
| user: Union[Account, EndUser], | |||
| user_inputs: dict, | |||
| system_inputs: Optional[dict] = None) -> WorkflowRun: | |||
| """ | |||
| Init workflow run | |||
| :param workflow: Workflow instance | |||
| :param triggered_from: triggered from | |||
| :param user: account or end user | |||
| :param user_inputs: user variables inputs | |||
| :param system_inputs: system inputs, like: query, files | |||
| :return: | |||
| """ | |||
| max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ | |||
| .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ | |||
| .filter(WorkflowRun.app_id == workflow.app_id) \ | |||
| .scalar() or 0 | |||
| class WorkflowCycleManage: | |||
| _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _task_state: WorkflowTaskState | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| def _handle_workflow_run_start(self) -> WorkflowRun: | |||
| max_sequence = ( | |||
| db.session.query(db.func.max(WorkflowRun.sequence_number)) | |||
| .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) | |||
| .filter(WorkflowRun.app_id == self._workflow.app_id) | |||
| .scalar() | |||
| or 0 | |||
| ) | |||
| new_sequence_number = max_sequence + 1 | |||
| inputs = {**user_inputs} | |||
| for key, value in (system_inputs or {}).items(): | |||
| inputs = {**self._application_generate_entity.inputs} | |||
| for key, value in (self._workflow_system_variables or {}).items(): | |||
| if key.value == 'conversation': | |||
| continue | |||
| inputs[f'sys.{key.value}'] = value | |||
| inputs = WorkflowEngineManager.handle_special_values(inputs) | |||
| inputs = WorkflowEntry.handle_special_values(inputs) | |||
| triggered_from= ( | |||
| WorkflowRunTriggeredFrom.DEBUGGING | |||
| if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER | |||
| else WorkflowRunTriggeredFrom.APP_RUN | |||
| ) | |||
| # init workflow run | |||
| workflow_run = WorkflowRun( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| sequence_number=new_sequence_number, | |||
| workflow_id=workflow.id, | |||
| type=workflow.type, | |||
| triggered_from=triggered_from.value, | |||
| version=workflow.version, | |||
| graph=workflow.graph, | |||
| inputs=json.dumps(inputs), | |||
| status=WorkflowRunStatus.RUNNING.value, | |||
| created_by_role=(CreatedByRole.ACCOUNT.value | |||
| if isinstance(user, Account) else CreatedByRole.END_USER.value), | |||
| created_by=user.id | |||
| workflow_run = WorkflowRun() | |||
| workflow_run.tenant_id = self._workflow.tenant_id | |||
| workflow_run.app_id = self._workflow.app_id | |||
| workflow_run.sequence_number = new_sequence_number | |||
| workflow_run.workflow_id = self._workflow.id | |||
| workflow_run.type = self._workflow.type | |||
| workflow_run.triggered_from = triggered_from.value | |||
| workflow_run.version = self._workflow.version | |||
| workflow_run.graph = self._workflow.graph | |||
| workflow_run.inputs = json.dumps(inputs) | |||
| workflow_run.status = WorkflowRunStatus.RUNNING.value | |||
| workflow_run.created_by_role = ( | |||
| CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value | |||
| ) | |||
| workflow_run.created_by = self._user.id | |||
| db.session.add(workflow_run) | |||
| db.session.commit() | |||
| @@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return workflow_run | |||
| def _workflow_run_success( | |||
| self, workflow_run: WorkflowRun, | |||
| def _handle_workflow_run_success( | |||
| self, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| outputs: Optional[str] = None, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run success | |||
| :param workflow_run: workflow run | |||
| :param start_at: start time | |||
| :param total_tokens: total tokens | |||
| :param total_steps: total steps | |||
| :param outputs: outputs | |||
| :param conversation_id: conversation id | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| workflow_run.status = WorkflowRunStatus.SUCCEEDED.value | |||
| workflow_run.outputs = outputs | |||
| workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| db.session.close() | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| ) | |||
| ) | |||
| db.session.close() | |||
| return workflow_run | |||
| def _workflow_run_failed( | |||
| self, workflow_run: WorkflowRun, | |||
| def _handle_workflow_run_failed( | |||
| self, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| status: WorkflowRunStatus, | |||
| error: str, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run failed | |||
| :param workflow_run: workflow run | |||
| :param start_at: start time | |||
| :param total_tokens: total tokens | |||
| :param total_steps: total steps | |||
| :param status: status | |||
| :param error: error message | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| workflow_run.status = status.value | |||
| workflow_run.error = error | |||
| workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | |||
| WorkflowNodeExecution.app_id == workflow_run.app_id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run.id, | |||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value | |||
| ).all() | |||
| for workflow_node_execution in running_workflow_node_executions: | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| db.session.close() | |||
| @@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return workflow_run | |||
| def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_title: str, | |||
| node_run_index: int = 1, | |||
| predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: | |||
| """ | |||
| Init workflow node execution from workflow run | |||
| :param workflow_run: workflow run | |||
| :param node_id: node id | |||
| :param node_type: node type | |||
| :param node_title: node title | |||
| :param node_run_index: run index | |||
| :param predecessor_node_id: predecessor node id if exists | |||
| :return: | |||
| """ | |||
| def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: | |||
| # init workflow node execution | |||
| workflow_node_execution = WorkflowNodeExecution( | |||
| tenant_id=workflow_run.tenant_id, | |||
| app_id=workflow_run.app_id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| workflow_run_id=workflow_run.id, | |||
| predecessor_node_id=predecessor_node_id, | |||
| index=node_run_index, | |||
| node_id=node_id, | |||
| node_type=node_type.value, | |||
| title=node_title, | |||
| status=WorkflowNodeExecutionStatus.RUNNING.value, | |||
| created_by_role=workflow_run.created_by_role, | |||
| created_by=workflow_run.created_by, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| ) | |||
| workflow_node_execution = WorkflowNodeExecution() | |||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||
| workflow_node_execution.app_id = workflow_run.app_id | |||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||
| workflow_node_execution.predecessor_node_id = event.predecessor_node_id | |||
| workflow_node_execution.index = event.node_run_index | |||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||
| workflow_node_execution.node_id = event.node_id | |||
| workflow_node_execution.node_type = event.node_type.value | |||
| workflow_node_execution.title = event.node_data.title | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value | |||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||
| workflow_node_execution.created_by = workflow_run.created_by | |||
| workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| @@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return workflow_node_execution | |||
| def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, | |||
| start_at: float, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: | |||
| def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution success | |||
| :param workflow_node_execution: workflow node execution | |||
| :param start_at: start time | |||
| :param inputs: inputs | |||
| :param process_data: process data | |||
| :param outputs: outputs | |||
| :param execution_metadata: execution metadata | |||
| :param event: queue node succeeded event | |||
| :return: | |||
| """ | |||
| inputs = WorkflowEngineManager.handle_special_values(inputs) | |||
| outputs = WorkflowEngineManager.handle_special_values(outputs) | |||
| workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - start_at | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ | |||
| if execution_metadata else None | |||
| workflow_node_execution.execution_metadata = ( | |||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||
| ) | |||
| workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() | |||
| db.session.commit() | |||
| db.session.refresh(workflow_node_execution) | |||
| @@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return workflow_node_execution | |||
| def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, | |||
| start_at: float, | |||
| error: str, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None | |||
| ) -> WorkflowNodeExecution: | |||
| def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param workflow_node_execution: workflow node execution | |||
| :param start_at: start time | |||
| :param error: error message | |||
| :param event: queue node failed event | |||
| :return: | |||
| """ | |||
| inputs = WorkflowEngineManager.handle_special_values(inputs) | |||
| outputs = WorkflowEngineManager.handle_special_values(outputs) | |||
| workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - start_at | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ | |||
| if execution_metadata else None | |||
| workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() | |||
| db.session.commit() | |||
| db.session.refresh(workflow_node_execution) | |||
| @@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return workflow_node_execution | |||
| def _workflow_start_to_stream_response(self, task_id: str, | |||
| workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: | |||
| ################################################# | |||
| # to stream responses # | |||
| ################################################# | |||
| def _workflow_start_to_stream_response( | |||
| self, task_id: str, workflow_run: WorkflowRun | |||
| ) -> WorkflowStartStreamResponse: | |||
| """ | |||
| Workflow start to stream response. | |||
| :param task_id: task id | |||
| @@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| id=workflow_run.id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| sequence_number=workflow_run.sequence_number, | |||
| inputs=workflow_run.inputs_dict, | |||
| created_at=int(workflow_run.created_at.timestamp()) | |||
| ) | |||
| inputs=workflow_run.inputs_dict or {}, | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| ), | |||
| ) | |||
| def _workflow_finish_to_stream_response(self, task_id: str, | |||
| workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: | |||
| def _workflow_finish_to_stream_response( | |||
| self, task_id: str, workflow_run: WorkflowRun | |||
| ) -> WorkflowFinishStreamResponse: | |||
| """ | |||
| Workflow finish to stream response. | |||
| :param task_id: task id | |||
| @@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| created_by_account = workflow_run.created_by_account | |||
| if created_by_account: | |||
| created_by = { | |||
| "id": created_by_account.id, | |||
| "name": created_by_account.name, | |||
| "email": created_by_account.email, | |||
| 'id': created_by_account.id, | |||
| 'name': created_by_account.name, | |||
| 'email': created_by_account.email, | |||
| } | |||
| else: | |||
| created_by_end_user = workflow_run.created_by_end_user | |||
| if created_by_end_user: | |||
| created_by = { | |||
| "id": created_by_end_user.id, | |||
| "user": created_by_end_user.session_id, | |||
| 'id': created_by_end_user.id, | |||
| 'user': created_by_end_user.session_id, | |||
| } | |||
| return WorkflowFinishStreamResponse( | |||
| @@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| created_by=created_by, | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| finished_at=int(workflow_run.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) | |||
| ) | |||
| files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), | |||
| ), | |||
| ) | |||
| def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution) \ | |||
| -> NodeStartStreamResponse: | |||
| def _workflow_node_start_to_stream_response( | |||
| self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution | |||
| ) -> Optional[NodeStartStreamResponse]: | |||
| """ | |||
| Workflow node start to stream response. | |||
| :param event: queue node started event | |||
| @@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| :param workflow_node_execution: workflow node execution | |||
| :return: | |||
| """ | |||
| if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: | |||
| return None | |||
| response = NodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| @@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| index=workflow_node_execution.index, | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs_dict, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()) | |||
| ) | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| ), | |||
| ) | |||
| # extras logic | |||
| @@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| response.data.extras['icon'] = ToolManager.get_tool_icon( | |||
| tenant_id=self._application_generate_entity.app_config.tenant_id, | |||
| provider_type=node_data.provider_type, | |||
| provider_id=node_data.provider_id | |||
| provider_id=node_data.provider_id, | |||
| ) | |||
| return response | |||
| def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ | |||
| -> NodeFinishStreamResponse: | |||
| def _workflow_node_finish_to_stream_response( | |||
| self, | |||
| event: QueueNodeSucceededEvent | QueueNodeFailedEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| """ | |||
| Workflow node finish to stream response. | |||
| :param event: queue node succeeded or failed event | |||
| :param task_id: task id | |||
| :param workflow_node_execution: workflow node execution | |||
| :return: | |||
| """ | |||
| if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: | |||
| return None | |||
| return NodeFinishStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| @@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) | |||
| ) | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| ), | |||
| ) | |||
| def _handle_workflow_start(self) -> WorkflowRun: | |||
| self._task_state.start_at = time.perf_counter() | |||
| workflow_run = self._init_workflow_run( | |||
| workflow=self._workflow, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING | |||
| if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER | |||
| else WorkflowRunTriggeredFrom.APP_RUN, | |||
| user=self._user, | |||
| user_inputs=self._application_generate_entity.inputs, | |||
| system_inputs=self._workflow_system_variables | |||
| def _workflow_parallel_branch_start_to_stream_response( | |||
| self, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| event: QueueParallelBranchRunStartedEvent | |||
| ) -> ParallelBranchStartStreamResponse: | |||
| """ | |||
| Workflow parallel branch start to stream response | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :param event: parallel branch run started event | |||
| :return: | |||
| """ | |||
| return ParallelBranchStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| data=ParallelBranchStartStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| created_at=int(time.time()), | |||
| ) | |||
| ) | |||
| self._task_state.workflow_run_id = workflow_run.id | |||
| db.session.close() | |||
| return workflow_run | |||
| def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: | |||
| workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() | |||
| workflow_node_execution = self._init_node_execution_from_workflow_run( | |||
| workflow_run=workflow_run, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_data.title, | |||
| node_run_index=event.node_run_index, | |||
| predecessor_node_id=event.predecessor_node_id | |||
| def _workflow_parallel_branch_finished_to_stream_response( | |||
| self, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent | |||
| ) -> ParallelBranchFinishedStreamResponse: | |||
| """ | |||
| Workflow parallel branch finished to stream response | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :param event: parallel branch run succeeded or failed event | |||
| :return: | |||
| """ | |||
| return ParallelBranchFinishedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| data=ParallelBranchFinishedStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', | |||
| error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, | |||
| created_at=int(time.time()), | |||
| ) | |||
| ) | |||
| latest_node_execution_info = NodeExecutionInfo( | |||
| workflow_node_execution_id=workflow_node_execution.id, | |||
| node_type=event.node_type, | |||
| start_at=time.perf_counter() | |||
| def _workflow_iteration_start_to_stream_response( | |||
| self, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| event: QueueIterationStartEvent | |||
| ) -> IterationNodeStartStreamResponse: | |||
| """ | |||
| Workflow iteration start to stream response | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :param event: iteration start event | |||
| :return: | |||
| """ | |||
| return IterationNodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| data=IterationNodeStartStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs or {}, | |||
| metadata=event.metadata or {}, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| ) | |||
| self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info | |||
| self._task_state.latest_node_execution_info = latest_node_execution_info | |||
| self._task_state.total_steps += 1 | |||
| db.session.close() | |||
| return workflow_node_execution | |||
| def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: | |||
| current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] | |||
| workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() | |||
| execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None | |||
| if self._iteration_state and self._iteration_state.current_iterations: | |||
| if not execution_metadata: | |||
| execution_metadata = {} | |||
| current_iteration_data = None | |||
| for iteration_node_id in self._iteration_state.current_iterations: | |||
| data = self._iteration_state.current_iterations[iteration_node_id] | |||
| if data.parent_iteration_id == None: | |||
| current_iteration_data = data | |||
| break | |||
| if current_iteration_data: | |||
| execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id | |||
| execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index | |||
| if isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._workflow_node_execution_success( | |||
| workflow_node_execution=workflow_node_execution, | |||
| start_at=current_node_execution.start_at, | |||
| inputs=event.inputs, | |||
| process_data=event.process_data, | |||
| outputs=event.outputs, | |||
| execution_metadata=execution_metadata | |||
| def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: | |||
| """ | |||
| Workflow iteration next to stream response | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :param event: iteration next event | |||
| :return: | |||
| """ | |||
| return IterationNodeNextStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| data=IterationNodeNextStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| index=event.index, | |||
| pre_iteration_output=event.output, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| ) | |||
| if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| self._task_state.total_tokens += ( | |||
| int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) | |||
| if self._iteration_state: | |||
| for iteration_node_id in self._iteration_state.current_iterations: | |||
| data = self._iteration_state.current_iterations[iteration_node_id] | |||
| if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) | |||
| if workflow_node_execution.node_type == NodeType.LLM.value: | |||
| outputs = workflow_node_execution.outputs_dict | |||
| usage_dict = outputs.get('usage', {}) | |||
| self._task_state.metadata['usage'] = usage_dict | |||
| else: | |||
| workflow_node_execution = self._workflow_node_execution_failed( | |||
| workflow_node_execution=workflow_node_execution, | |||
| start_at=current_node_execution.start_at, | |||
| error=event.error, | |||
| inputs=event.inputs, | |||
| process_data=event.process_data, | |||
| def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: | |||
| """ | |||
| Workflow iteration completed to stream response | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :param event: iteration completed event | |||
| :return: | |||
| """ | |||
| return IterationNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| data=IterationNodeCompletedStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| outputs=event.outputs, | |||
| execution_metadata=execution_metadata | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs or {}, | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| error=None, | |||
| elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), | |||
| total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, | |||
| execution_metadata=event.metadata, | |||
| finished_at=int(time.time()), | |||
| steps=event.steps, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| db.session.close() | |||
| return workflow_node_execution | |||
| def _handle_workflow_finished( | |||
| self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None | |||
| ) -> Optional[WorkflowRun]: | |||
| workflow_run = db.session.query(WorkflowRun).filter( | |||
| WorkflowRun.id == self._task_state.workflow_run_id).first() | |||
| if not workflow_run: | |||
| return None | |||
| if conversation_id is None: | |||
| conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') | |||
| if isinstance(event, QueueStopEvent): | |||
| workflow_run = self._workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| total_tokens=self._task_state.total_tokens, | |||
| total_steps=self._task_state.total_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| error='Workflow stopped.', | |||
| conversation_id=conversation_id, | |||
| trace_manager=trace_manager | |||
| ) | |||
| latest_node_execution_info = self._task_state.latest_node_execution_info | |||
| if latest_node_execution_info: | |||
| workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() | |||
| if (workflow_node_execution | |||
| and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): | |||
| self._workflow_node_execution_failed( | |||
| workflow_node_execution=workflow_node_execution, | |||
| start_at=latest_node_execution_info.start_at, | |||
| error='Workflow stopped.' | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent): | |||
| workflow_run = self._workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| total_tokens=self._task_state.total_tokens, | |||
| total_steps=self._task_state.total_steps, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error=event.error, | |||
| conversation_id=conversation_id, | |||
| trace_manager=trace_manager | |||
| ) | |||
| else: | |||
| if self._task_state.latest_node_execution_info: | |||
| workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() | |||
| outputs = workflow_node_execution.outputs | |||
| else: | |||
| outputs = None | |||
| workflow_run = self._workflow_run_success( | |||
| workflow_run=workflow_run, | |||
| total_tokens=self._task_state.total_tokens, | |||
| total_steps=self._task_state.total_steps, | |||
| outputs=outputs, | |||
| conversation_id=conversation_id, | |||
| trace_manager=trace_manager | |||
| ) | |||
| self._task_state.workflow_run_id = workflow_run.id | |||
| db.session.close() | |||
| return workflow_run | |||
| ) | |||
| def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: | |||
| """ | |||
| @@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): | |||
| return value.to_dict() | |||
| return None | |||
| def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: | |||
| """ | |||
| Refetch workflow run | |||
| :param workflow_run_id: workflow run id | |||
| :return: | |||
| """ | |||
| workflow_run = db.session.query(WorkflowRun).filter( | |||
| WorkflowRun.id == workflow_run_id).first() | |||
| if not workflow_run: | |||
| raise Exception(f'Workflow run not found: {workflow_run_id}') | |||
| return workflow_run | |||
| def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: | |||
| """ | |||
| Refetch workflow node execution | |||
| :param node_execution_id: workflow node execution id | |||
| :return: | |||
| """ | |||
| workflow_node_execution = ( | |||
| db.session.query(WorkflowNodeExecution) | |||
| .filter( | |||
| WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, | |||
| WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, | |||
| WorkflowNodeExecution.workflow_id == self._workflow.id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.node_execution_id == node_execution_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not workflow_node_execution: | |||
| raise Exception(f'Workflow node execution not found: {node_execution_id}') | |||
| return workflow_node_execution | |||
| @@ -1,16 +0,0 @@ | |||
| from typing import Any, Union | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState | |||
| from core.workflow.enums import SystemVariableKey | |||
| from models.account import Account | |||
| from models.model import EndUser | |||
| from models.workflow import Workflow | |||
| class WorkflowCycleStateManager: | |||
| _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| @@ -1,290 +0,0 @@ | |||
| import json | |||
| import time | |||
| from collections.abc import Generator | |||
| from datetime import datetime, timezone | |||
| from typing import Optional, Union | |||
| from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| IterationNodeCompletedStreamResponse, | |||
| IterationNodeNextStreamResponse, | |||
| IterationNodeStartStreamResponse, | |||
| NodeExecutionInfo, | |||
| WorkflowIterationState, | |||
| ) | |||
| from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from extensions.ext_database import db | |||
| from models.workflow import ( | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowRun, | |||
| ) | |||
| class WorkflowIterationCycleManage(WorkflowCycleStateManager): | |||
| _iteration_state: WorkflowIterationState = None | |||
| def _init_iteration_state(self) -> WorkflowIterationState: | |||
| if not self._iteration_state: | |||
| self._iteration_state = WorkflowIterationState( | |||
| current_iterations={} | |||
| ) | |||
| def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ | |||
| -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: | |||
| """ | |||
| Handle iteration to stream response | |||
| :param task_id: task id | |||
| :param event: iteration event | |||
| :return: | |||
| """ | |||
| if isinstance(event, QueueIterationStartEvent): | |||
| return IterationNodeStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=self._task_state.workflow_run_id, | |||
| data=IterationNodeStartStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs, | |||
| metadata=event.metadata | |||
| ) | |||
| ) | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| current_iteration = self._iteration_state.current_iterations[event.node_id] | |||
| return IterationNodeNextStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=self._task_state.workflow_run_id, | |||
| data=IterationNodeNextStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=current_iteration.node_data.title, | |||
| index=event.index, | |||
| pre_iteration_output=event.output, | |||
| created_at=int(time.time()), | |||
| extras={} | |||
| ) | |||
| ) | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| current_iteration = self._iteration_state.current_iterations[event.node_id] | |||
| return IterationNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=self._task_state.workflow_run_id, | |||
| data=IterationNodeCompletedStreamResponse.Data( | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=current_iteration.node_data.title, | |||
| outputs=event.outputs, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=current_iteration.inputs, | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| error=None, | |||
| elapsed_time=time.perf_counter() - current_iteration.started_at, | |||
| total_tokens=current_iteration.total_tokens, | |||
| execution_metadata={ | |||
| 'total_tokens': current_iteration.total_tokens, | |||
| }, | |||
| finished_at=int(time.time()), | |||
| steps=current_iteration.current_index | |||
| ) | |||
| ) | |||
| def _init_iteration_execution_from_workflow_run(self, | |||
| workflow_run: WorkflowRun, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_title: str, | |||
| node_run_index: int = 1, | |||
| inputs: Optional[dict] = None, | |||
| predecessor_node_id: Optional[str] = None | |||
| ) -> WorkflowNodeExecution: | |||
| workflow_node_execution = WorkflowNodeExecution( | |||
| tenant_id=workflow_run.tenant_id, | |||
| app_id=workflow_run.app_id, | |||
| workflow_id=workflow_run.workflow_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| workflow_run_id=workflow_run.id, | |||
| predecessor_node_id=predecessor_node_id, | |||
| index=node_run_index, | |||
| node_id=node_id, | |||
| node_type=node_type.value, | |||
| inputs=json.dumps(inputs) if inputs else None, | |||
| title=node_title, | |||
| status=WorkflowNodeExecutionStatus.RUNNING.value, | |||
| created_by_role=workflow_run.created_by_role, | |||
| created_by=workflow_run.created_by, | |||
| execution_metadata=json.dumps({ | |||
| 'started_run_index': node_run_index + 1, | |||
| 'current_index': 0, | |||
| 'steps_boundary': [], | |||
| }), | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| ) | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| db.session.refresh(workflow_node_execution) | |||
| db.session.close() | |||
| return workflow_node_execution | |||
| def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: | |||
| if isinstance(event, QueueIterationStartEvent): | |||
| return self._handle_iteration_started(event) | |||
| elif isinstance(event, QueueIterationNextEvent): | |||
| return self._handle_iteration_next(event) | |||
| elif isinstance(event, QueueIterationCompletedEvent): | |||
| return self._handle_iteration_completed(event) | |||
| def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: | |||
| self._init_iteration_state() | |||
| workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() | |||
| workflow_node_execution = self._init_iteration_execution_from_workflow_run( | |||
| workflow_run=workflow_run, | |||
| node_id=event.node_id, | |||
| node_type=NodeType.ITERATION, | |||
| node_title=event.node_data.title, | |||
| node_run_index=event.node_run_index, | |||
| inputs=event.inputs, | |||
| predecessor_node_id=event.predecessor_node_id | |||
| ) | |||
| latest_node_execution_info = NodeExecutionInfo( | |||
| workflow_node_execution_id=workflow_node_execution.id, | |||
| node_type=NodeType.ITERATION, | |||
| start_at=time.perf_counter() | |||
| ) | |||
| self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info | |||
| self._task_state.latest_node_execution_info = latest_node_execution_info | |||
| self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( | |||
| parent_iteration_id=None, | |||
| iteration_id=event.node_id, | |||
| current_index=0, | |||
| iteration_steps_boundary=[], | |||
| node_execution_id=workflow_node_execution.id, | |||
| started_at=time.perf_counter(), | |||
| inputs=event.inputs, | |||
| total_tokens=0, | |||
| node_data=event.node_data | |||
| ) | |||
| db.session.close() | |||
| return workflow_node_execution | |||
| def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: | |||
| if event.node_id not in self._iteration_state.current_iterations: | |||
| return | |||
| current_iteration = self._iteration_state.current_iterations[event.node_id] | |||
| current_iteration.current_index = event.index | |||
| current_iteration.iteration_steps_boundary.append(event.node_run_index) | |||
| workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == current_iteration.node_execution_id | |||
| ).first() | |||
| original_node_execution_metadata = workflow_node_execution.execution_metadata_dict | |||
| if original_node_execution_metadata: | |||
| original_node_execution_metadata['current_index'] = event.index | |||
| original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary | |||
| original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens | |||
| workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) | |||
| db.session.commit() | |||
| db.session.close() | |||
| def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): | |||
| if event.node_id not in self._iteration_state.current_iterations: | |||
| return | |||
| current_iteration = self._iteration_state.current_iterations[event.node_id] | |||
| workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == current_iteration.node_execution_id | |||
| ).first() | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at | |||
| original_node_execution_metadata = workflow_node_execution.execution_metadata_dict | |||
| if original_node_execution_metadata: | |||
| original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary | |||
| original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens | |||
| workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) | |||
| db.session.commit() | |||
| # remove current iteration | |||
| self._iteration_state.current_iterations.pop(event.node_id, None) | |||
| # set latest node execution info | |||
| latest_node_execution_info = NodeExecutionInfo( | |||
| workflow_node_execution_id=workflow_node_execution.id, | |||
| node_type=NodeType.ITERATION, | |||
| start_at=time.perf_counter() | |||
| ) | |||
| self._task_state.latest_node_execution_info = latest_node_execution_info | |||
| db.session.close() | |||
| def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: | |||
| """ | |||
| Handle iteration exception | |||
| """ | |||
| if not self._iteration_state or not self._iteration_state.current_iterations: | |||
| return | |||
| for node_id, current_iteration in self._iteration_state.current_iterations.items(): | |||
| workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.id == current_iteration.node_execution_id | |||
| ).first() | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at | |||
| db.session.commit() | |||
| db.session.close() | |||
| yield IterationNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=self._task_state.workflow_run_id, | |||
| data=IterationNodeCompletedStreamResponse.Data( | |||
| id=node_id, | |||
| node_id=node_id, | |||
| node_type=NodeType.ITERATION.value, | |||
| title=current_iteration.node_data.title, | |||
| outputs={}, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=current_iteration.inputs, | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=error, | |||
| elapsed_time=time.perf_counter() - current_iteration.started_at, | |||
| total_tokens=current_iteration.total_tokens, | |||
| execution_metadata={ | |||
| 'total_tokens': current_iteration.total_tokens, | |||
| }, | |||
| finished_at=int(time.time()), | |||
| steps=current_iteration.current_index | |||
| ) | |||
| ) | |||
| @@ -63,6 +63,39 @@ class LLMUsage(ModelUsage): | |||
| latency=0.0 | |||
| ) | |||
| def plus(self, other: 'LLMUsage') -> 'LLMUsage': | |||
| """ | |||
| Add two LLMUsage instances together. | |||
| :param other: Another LLMUsage instance to add | |||
| :return: A new LLMUsage instance with summed values | |||
| """ | |||
| if self.total_tokens == 0: | |||
| return other | |||
| else: | |||
| return LLMUsage( | |||
| prompt_tokens=self.prompt_tokens + other.prompt_tokens, | |||
| prompt_unit_price=other.prompt_unit_price, | |||
| prompt_price_unit=other.prompt_price_unit, | |||
| prompt_price=self.prompt_price + other.prompt_price, | |||
| completion_tokens=self.completion_tokens + other.completion_tokens, | |||
| completion_unit_price=other.completion_unit_price, | |||
| completion_price_unit=other.completion_price_unit, | |||
| completion_price=self.completion_price + other.completion_price, | |||
| total_tokens=self.total_tokens + other.total_tokens, | |||
| total_price=self.total_price + other.total_price, | |||
| currency=other.currency, | |||
| latency=self.latency + other.latency | |||
| ) | |||
| def __add__(self, other: 'LLMUsage') -> 'LLMUsage': | |||
| """ | |||
| Overload the + operator to add two LLMUsage instances. | |||
| :param other: Another LLMUsage instance to add | |||
| :return: A new LLMUsage instance with summed values | |||
| """ | |||
| return self.plus(other) | |||
| class LLMResult(BaseModel): | |||
| """ | |||
| @@ -34,13 +34,13 @@ class OutputModeration(BaseModel): | |||
| final_output: Optional[str] = None | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| def should_direct_output(self): | |||
| def should_direct_output(self) -> bool: | |||
| return self.final_output is not None | |||
| def get_final_output(self): | |||
| return self.final_output | |||
| def get_final_output(self) -> str: | |||
| return self.final_output or "" | |||
| def append_new_token(self, token: str): | |||
| def append_new_token(self, token: str) -> None: | |||
| self.buffer += token | |||
| if not self.thread: | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| import logging | |||
| from copy import deepcopy | |||
| from typing import Any, Union | |||
| from typing import Any, Optional, Union | |||
| from core.file.file_obj import FileTransferMethod, FileVar | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType | |||
| @@ -18,6 +18,7 @@ class WorkflowTool(Tool): | |||
| version: str | |||
| workflow_entities: dict[str, Any] | |||
| workflow_call_depth: int | |||
| thread_pool_id: Optional[str] = None | |||
| label: str | |||
| @@ -57,6 +58,7 @@ class WorkflowTool(Tool): | |||
| invoke_from=self.runtime.invoke_from, | |||
| stream=False, | |||
| call_depth=self.workflow_call_depth + 1, | |||
| workflow_thread_pool_id=self.thread_pool_id | |||
| ) | |||
| data = result.get('data', {}) | |||
| @@ -128,6 +128,7 @@ class ToolEngine: | |||
| user_id: str, | |||
| workflow_tool_callback: DifyWorkflowCallbackHandler, | |||
| workflow_call_depth: int, | |||
| thread_pool_id: Optional[str] = None | |||
| ) -> list[ToolInvokeMessage]: | |||
| """ | |||
| Workflow invokes the tool with the given arguments. | |||
| @@ -141,6 +142,7 @@ class ToolEngine: | |||
| if isinstance(tool, WorkflowTool): | |||
| tool.workflow_call_depth = workflow_call_depth + 1 | |||
| tool.thread_pool_id = thread_pool_id | |||
| if tool.runtime and tool.runtime.runtime_parameters: | |||
| tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} | |||
| @@ -25,7 +25,6 @@ from core.tools.tool.tool import Tool | |||
| from core.tools.tool_label_manager import ToolLabelManager | |||
| from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| from core.workflow.nodes.tool.entities import ToolEntity | |||
| from extensions.ext_database import db | |||
| from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider | |||
| from services.tools.tools_transform_service import ToolTransformService | |||
| @@ -249,7 +248,7 @@ class ToolManager: | |||
| return tool_entity | |||
| @classmethod | |||
| def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: | |||
| def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: | |||
| """ | |||
| get the workflow tool runtime | |||
| """ | |||
| @@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager | |||
| logger = logging.getLogger(__name__) | |||
| class ToolFileMessageTransformer: | |||
| @classmethod | |||
| def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], | |||
| @@ -1,116 +1,15 @@ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Optional | |||
| from core.app.entities.queue_entities import AppQueueEvent | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.graph_engine.entities.event import GraphEngineEvent | |||
| class WorkflowCallback(ABC): | |||
| @abstractmethod | |||
| def on_workflow_run_started(self) -> None: | |||
| def on_event( | |||
| self, | |||
| event: GraphEngineEvent | |||
| ) -> None: | |||
| """ | |||
| Workflow run started | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_run_succeeded(self) -> None: | |||
| """ | |||
| Workflow run succeeded | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_run_failed(self, error: str) -> None: | |||
| """ | |||
| Workflow run failed | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_node_execute_started(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| node_run_index: int = 1, | |||
| predecessor_node_id: Optional[str] = None) -> None: | |||
| """ | |||
| Workflow node execute started | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_node_execute_succeeded(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| inputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| execution_metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute succeeded | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_node_execute_failed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData, | |||
| error: str, | |||
| inputs: Optional[dict] = None, | |||
| outputs: Optional[dict] = None, | |||
| process_data: Optional[dict] = None) -> None: | |||
| """ | |||
| Workflow node execute failed | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish text chunk | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_iteration_started(self, | |||
| node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int = 1, | |||
| node_data: Optional[BaseNodeData] = None, | |||
| inputs: Optional[dict] = None, | |||
| predecessor_node_id: Optional[str] = None, | |||
| metadata: Optional[dict] = None) -> None: | |||
| """ | |||
| Publish iteration started | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_iteration_next(self, node_id: str, | |||
| node_type: NodeType, | |||
| index: int, | |||
| node_run_index: int, | |||
| output: Optional[Any], | |||
| ) -> None: | |||
| """ | |||
| Publish iteration next | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_workflow_iteration_completed(self, node_id: str, | |||
| node_type: NodeType, | |||
| node_run_index: int, | |||
| outputs: dict) -> None: | |||
| """ | |||
| Publish iteration completed | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def on_event(self, event: AppQueueEvent) -> None: | |||
| """ | |||
| Publish event | |||
| Published event | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel): | |||
| desc: Optional[str] = None | |||
| class BaseIterationNodeData(BaseNodeData): | |||
| start_node_id: str | |||
| start_node_id: Optional[str] = None | |||
| class BaseIterationState(BaseModel): | |||
| iteration_node_id: str | |||
| @@ -1,9 +1,9 @@ | |||
| from collections.abc import Mapping | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from models import WorkflowNodeExecutionStatus | |||
| @@ -28,6 +28,7 @@ class NodeType(Enum): | |||
| VARIABLE_ASSIGNER = 'variable-assigner' | |||
| LOOP = 'loop' | |||
| ITERATION = 'iteration' | |||
| ITERATION_START = 'iteration-start' # fake start node for iteration | |||
| PARAMETER_EXTRACTOR = 'parameter-extractor' | |||
| CONVERSATION_VARIABLE_ASSIGNER = 'assigner' | |||
| @@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum): | |||
| TOOL_INFO = 'tool_info' | |||
| ITERATION_ID = 'iteration_id' | |||
| ITERATION_INDEX = 'iteration_index' | |||
| PARALLEL_ID = 'parallel_id' | |||
| PARALLEL_START_NODE_ID = 'parallel_start_node_id' | |||
| PARENT_PARALLEL_ID = 'parent_parallel_id' | |||
| PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' | |||
| class NodeRunResult(BaseModel): | |||
| @@ -65,11 +70,32 @@ class NodeRunResult(BaseModel): | |||
| status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING | |||
| inputs: Optional[Mapping[str, Any]] = None # node inputs | |||
| process_data: Optional[dict] = None # process data | |||
| outputs: Optional[Mapping[str, Any]] = None # node outputs | |||
| inputs: Optional[dict[str, Any]] = None # node inputs | |||
| process_data: Optional[dict[str, Any]] = None # process data | |||
| outputs: Optional[dict[str, Any]] = None # node outputs | |||
| metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata | |||
| llm_usage: Optional[LLMUsage] = None # llm usage | |||
| edge_source_handle: Optional[str] = None # source handle id of node with multiple branches | |||
| error: Optional[str] = None # error message if status is failed | |||
| class UserFrom(Enum): | |||
| """ | |||
| User from | |||
| """ | |||
| ACCOUNT = "account" | |||
| END_USER = "end-user" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "UserFrom": | |||
| """ | |||
| Value of | |||
| :param value: value | |||
| :return: | |||
| """ | |||
| for item in cls: | |||
| if item.value == value: | |||
| return item | |||
| raise ValueError(f"Invalid value: {value}") | |||
| @@ -2,6 +2,7 @@ from collections import defaultdict | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Union | |||
| from pydantic import BaseModel, Field, model_validator | |||
| from typing_extensions import deprecated | |||
| from core.app.segments import Segment, Variable, factory | |||
| @@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env" | |||
| CONVERSATION_VARIABLE_NODE_ID = "conversation" | |||
| class VariablePool: | |||
| def __init__( | |||
| self, | |||
| system_variables: Mapping[SystemVariableKey, Any], | |||
| user_inputs: Mapping[str, Any], | |||
| environment_variables: Sequence[Variable], | |||
| conversation_variables: Sequence[Variable] | None = None, | |||
| ) -> None: | |||
| # system variables | |||
| # for example: | |||
| # { | |||
| # 'query': 'abc', | |||
| # 'files': [] | |||
| # } | |||
| # Variable dictionary is a dictionary for looking up variables by their selector. | |||
| # The first element of the selector is the node id, it's the first-level key in the dictionary. | |||
| # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the | |||
| # elements of the selector except the first one. | |||
| self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) | |||
| # TODO: This user inputs is not used for pool. | |||
| self.user_inputs = user_inputs | |||
| class VariablePool(BaseModel): | |||
| # Variable dictionary is a dictionary for looking up variables by their selector. | |||
| # The first element of the selector is the node id, it's the first-level key in the dictionary. | |||
| # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the | |||
| # elements of the selector except the first one. | |||
| variable_dictionary: dict[str, dict[int, Segment]] = Field( | |||
| description='Variables mapping', | |||
| default=defaultdict(dict) | |||
| ) | |||
| # TODO: This user inputs is not used for pool. | |||
| user_inputs: Mapping[str, Any] = Field( | |||
| description='User inputs', | |||
| ) | |||
| system_variables: Mapping[SystemVariableKey, Any] = Field( | |||
| description='System variables', | |||
| ) | |||
| environment_variables: Sequence[Variable] = Field( | |||
| description="Environment variables.", | |||
| default_factory=list | |||
| ) | |||
| conversation_variables: Sequence[Variable] | None = None | |||
| @model_validator(mode="after") | |||
| def val_model_after(self): | |||
| """ | |||
| Append system variables | |||
| :return: | |||
| """ | |||
| # Add system variables to the variable pool | |||
| self.system_variables = system_variables | |||
| for key, value in system_variables.items(): | |||
| for key, value in self.system_variables.items(): | |||
| self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) | |||
| # Add environment variables to the variable pool | |||
| for var in environment_variables: | |||
| for var in self.environment_variables or []: | |||
| self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) | |||
| # Add conversation variables to the variable pool | |||
| for var in conversation_variables or []: | |||
| for var in self.conversation_variables or []: | |||
| self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) | |||
| return self | |||
| def add(self, selector: Sequence[str], value: Any, /) -> None: | |||
| """ | |||
| Adds a variable to the variable pool. | |||
| @@ -79,7 +89,7 @@ class VariablePool: | |||
| v = factory.build_segment(value) | |||
| hash_key = hash(tuple(selector[1:])) | |||
| self._variable_dictionary[selector[0]][hash_key] = v | |||
| self.variable_dictionary[selector[0]][hash_key] = v | |||
| def get(self, selector: Sequence[str], /) -> Segment | None: | |||
| """ | |||
| @@ -97,7 +107,7 @@ class VariablePool: | |||
| if len(selector) < 2: | |||
| raise ValueError("Invalid selector") | |||
| hash_key = hash(tuple(selector[1:])) | |||
| value = self._variable_dictionary[selector[0]].get(hash_key) | |||
| value = self.variable_dictionary[selector[0]].get(hash_key) | |||
| return value | |||
| @@ -118,7 +128,7 @@ class VariablePool: | |||
| if len(selector) < 2: | |||
| raise ValueError("Invalid selector") | |||
| hash_key = hash(tuple(selector[1:])) | |||
| value = self._variable_dictionary[selector[0]].get(hash_key) | |||
| value = self.variable_dictionary[selector[0]].get(hash_key) | |||
| return value.to_object() if value else None | |||
| def remove(self, selector: Sequence[str], /): | |||
| @@ -134,7 +144,19 @@ class VariablePool: | |||
| if not selector: | |||
| return | |||
| if len(selector) == 1: | |||
| self._variable_dictionary[selector[0]] = {} | |||
| self.variable_dictionary[selector[0]] = {} | |||
| return | |||
| hash_key = hash(tuple(selector[1:])) | |||
| self._variable_dictionary[selector[0]].pop(hash_key, None) | |||
| self.variable_dictionary[selector[0]].pop(hash_key, None) | |||
| def remove_node(self, node_id: str, /): | |||
| """ | |||
| Remove all variables associated with a given node id. | |||
| Args: | |||
| node_id (str): The node id to remove. | |||
| Returns: | |||
| None | |||
| """ | |||
| self.variable_dictionary.pop(node_id, None) | |||
| @@ -66,8 +66,7 @@ class WorkflowRunState: | |||
| self.variable_pool = variable_pool | |||
| self.total_tokens = 0 | |||
| self.workflow_nodes_and_results = [] | |||
| self.current_iteration_state = None | |||
| self.workflow_node_steps = 1 | |||
| self.workflow_node_runs = [] | |||
| self.workflow_node_runs = [] | |||
| self.current_iteration_state = None | |||
| @@ -1,10 +1,8 @@ | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| class WorkflowNodeRunFailedError(Exception): | |||
| def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): | |||
| self.node_id = node_id | |||
| self.node_type = node_type | |||
| self.node_title = node_title | |||
| def __init__(self, node_instance: BaseNode, error: str): | |||
| self.node_instance = node_instance | |||
| self.error = error | |||
| super().__init__(f"Node {node_title} run failed: {error}") | |||
| super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") | |||
| @@ -0,0 +1,31 @@ | |||
| from abc import ABC, abstractmethod | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.graph_engine.entities.run_condition import RunCondition | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| class RunConditionHandler(ABC): | |||
| def __init__(self, | |||
| init_params: GraphInitParams, | |||
| graph: Graph, | |||
| condition: RunCondition): | |||
| self.init_params = init_params | |||
| self.graph = graph | |||
| self.condition = condition | |||
| @abstractmethod | |||
| def check(self, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| previous_route_node_state: RouteNodeState | |||
| ) -> bool: | |||
| """ | |||
| Check if the condition can be executed | |||
| :param graph_runtime_state: graph runtime state | |||
| :param previous_route_node_state: previous route node state | |||
| :return: bool | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,28 @@ | |||
| from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| class BranchIdentifyRunConditionHandler(RunConditionHandler): | |||
| def check(self, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| previous_route_node_state: RouteNodeState) -> bool: | |||
| """ | |||
| Check if the condition can be executed | |||
| :param graph_runtime_state: graph runtime state | |||
| :param previous_route_node_state: previous route node state | |||
| :return: bool | |||
| """ | |||
| if not self.condition.branch_identify: | |||
| raise Exception("Branch identify is required") | |||
| run_result = previous_route_node_state.node_run_result | |||
| if not run_result: | |||
| return False | |||
| if not run_result.edge_source_handle: | |||
| return False | |||
| return self.condition.branch_identify == run_result.edge_source_handle | |||
| @@ -0,0 +1,32 @@ | |||
| from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| from core.workflow.utils.condition.processor import ConditionProcessor | |||
| class ConditionRunConditionHandlerHandler(RunConditionHandler): | |||
| def check(self, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| previous_route_node_state: RouteNodeState | |||
| ) -> bool: | |||
| """ | |||
| Check if the condition can be executed | |||
| :param graph_runtime_state: graph runtime state | |||
| :param previous_route_node_state: previous route node state | |||
| :return: bool | |||
| """ | |||
| if not self.condition.conditions: | |||
| return True | |||
| # process condition | |||
| condition_processor = ConditionProcessor() | |||
| input_conditions, group_result = condition_processor.process_conditions( | |||
| variable_pool=graph_runtime_state.variable_pool, | |||
| conditions=self.condition.conditions | |||
| ) | |||
| # Apply the logical operator for the current case | |||
| compare_result = all(group_result) | |||
| return compare_result | |||
| @@ -0,0 +1,35 @@ | |||
| from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler | |||
| from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler | |||
| from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.run_condition import RunCondition | |||
| class ConditionManager: | |||
| @staticmethod | |||
| def get_condition_handler( | |||
| init_params: GraphInitParams, | |||
| graph: Graph, | |||
| run_condition: RunCondition | |||
| ) -> RunConditionHandler: | |||
| """ | |||
| Get condition handler | |||
| :param init_params: init params | |||
| :param graph: graph | |||
| :param run_condition: run condition | |||
| :return: condition handler | |||
| """ | |||
| if run_condition.type == "branch_identify": | |||
| return BranchIdentifyRunConditionHandler( | |||
| init_params=init_params, | |||
| graph=graph, | |||
| condition=run_condition | |||
| ) | |||
| else: | |||
| return ConditionRunConditionHandlerHandler( | |||
| init_params=init_params, | |||
| graph=graph, | |||
| condition=run_condition | |||
| ) | |||
| @@ -0,0 +1,163 @@ | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| class GraphEngineEvent(BaseModel): | |||
| pass | |||
| ########################################### | |||
| # Graph Events | |||
| ########################################### | |||
| class BaseGraphEvent(GraphEngineEvent): | |||
| pass | |||
| class GraphRunStartedEvent(BaseGraphEvent): | |||
| pass | |||
| class GraphRunSucceededEvent(BaseGraphEvent): | |||
| outputs: Optional[dict[str, Any]] = None | |||
| """outputs""" | |||
| class GraphRunFailedEvent(BaseGraphEvent): | |||
| error: str = Field(..., description="failed reason") | |||
| ########################################### | |||
| # Node Events | |||
| ########################################### | |||
| class BaseNodeEvent(GraphEngineEvent): | |||
| id: str = Field(..., description="node execution id") | |||
| node_id: str = Field(..., description="node id") | |||
| node_type: NodeType = Field(..., description="node type") | |||
| node_data: BaseNodeData = Field(..., description="node data") | |||
| route_node_state: RouteNodeState = Field(..., description="route node state") | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class NodeRunStartedEvent(BaseNodeEvent): | |||
| predecessor_node_id: Optional[str] = None | |||
| """predecessor node id""" | |||
| class NodeRunStreamChunkEvent(BaseNodeEvent): | |||
| chunk_content: str = Field(..., description="chunk content") | |||
| from_variable_selector: Optional[list[str]] = None | |||
| """from variable selector""" | |||
| class NodeRunRetrieverResourceEvent(BaseNodeEvent): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| class NodeRunSucceededEvent(BaseNodeEvent): | |||
| pass | |||
| class NodeRunFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| ########################################### | |||
| # Parallel Branch Events | |||
| ########################################### | |||
| class BaseParallelBranchEvent(GraphEngineEvent): | |||
| parallel_id: str = Field(..., description="parallel id") | |||
| """parallel id""" | |||
| parallel_start_node_id: str = Field(..., description="parallel start node id") | |||
| """parallel start node id""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): | |||
| pass | |||
| class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): | |||
| pass | |||
| class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): | |||
| error: str = Field(..., description="failed reason") | |||
| ########################################### | |||
| # Iteration Events | |||
| ########################################### | |||
| class BaseIterationEvent(GraphEngineEvent): | |||
| iteration_id: str = Field(..., description="iteration node execution id") | |||
| iteration_node_id: str = Field(..., description="iteration node id") | |||
| iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") | |||
| iteration_node_data: BaseNodeData = Field(..., description="node data") | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| class IterationRunStartedEvent(BaseIterationEvent): | |||
| start_at: datetime = Field(..., description="start at") | |||
| inputs: Optional[dict[str, Any]] = None | |||
| metadata: Optional[dict[str, Any]] = None | |||
| predecessor_node_id: Optional[str] = None | |||
| class IterationRunNextEvent(BaseIterationEvent): | |||
| index: int = Field(..., description="index") | |||
| pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") | |||
| class IterationRunSucceededEvent(BaseIterationEvent): | |||
| start_at: datetime = Field(..., description="start at") | |||
| inputs: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| metadata: Optional[dict[str, Any]] = None | |||
| steps: int = 0 | |||
| class IterationRunFailedEvent(BaseIterationEvent): | |||
| start_at: datetime = Field(..., description="start at") | |||
| inputs: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| metadata: Optional[dict[str, Any]] = None | |||
| steps: int = 0 | |||
| error: str = Field(..., description="failed reason") | |||
| InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | |||
| @@ -0,0 +1,692 @@ | |||
| import uuid | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.graph_engine.entities.run_condition import RunCondition | |||
| from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter | |||
| from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter | |||
| from core.workflow.nodes.end.entities import EndStreamParam | |||
| class GraphEdge(BaseModel): | |||
| source_node_id: str = Field(..., description="source node id") | |||
| target_node_id: str = Field(..., description="target node id") | |||
| run_condition: Optional[RunCondition] = None | |||
| """run condition""" | |||
| class GraphParallel(BaseModel): | |||
| id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") | |||
| start_from_node_id: str = Field(..., description="start from node id") | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id""" | |||
| end_to_node_id: Optional[str] = None | |||
| """end to node id""" | |||
| class Graph(BaseModel): | |||
| root_node_id: str = Field(..., description="root node id of the graph") | |||
| node_ids: list[str] = Field(default_factory=list, description="graph node ids") | |||
| node_id_config_mapping: dict[str, dict] = Field( | |||
| default_factory=list, | |||
| description="node configs mapping (node id: node config)" | |||
| ) | |||
| edge_mapping: dict[str, list[GraphEdge]] = Field( | |||
| default_factory=dict, | |||
| description="graph edge mapping (source node id: edges)" | |||
| ) | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( | |||
| default_factory=dict, | |||
| description="reverse graph edge mapping (target node id: edges)" | |||
| ) | |||
| parallel_mapping: dict[str, GraphParallel] = Field( | |||
| default_factory=dict, | |||
| description="graph parallel mapping (parallel id: parallel)" | |||
| ) | |||
| node_parallel_mapping: dict[str, str] = Field( | |||
| default_factory=dict, | |||
| description="graph node parallel mapping (node id: parallel id)" | |||
| ) | |||
| answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( | |||
| ..., | |||
| description="answer stream generate routes" | |||
| ) | |||
| end_stream_param: EndStreamParam = Field( | |||
| ..., | |||
| description="end stream param" | |||
| ) | |||
| @classmethod | |||
| def init(cls, | |||
| graph_config: Mapping[str, Any], | |||
| root_node_id: Optional[str] = None) -> "Graph": | |||
| """ | |||
| Init graph | |||
| :param graph_config: graph config | |||
| :param root_node_id: root node id | |||
| :return: graph | |||
| """ | |||
| # edge configs | |||
| edge_configs = graph_config.get('edges') | |||
| if edge_configs is None: | |||
| edge_configs = [] | |||
| edge_configs = cast(list, edge_configs) | |||
| # reorganize edges mapping | |||
| edge_mapping: dict[str, list[GraphEdge]] = {} | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]] = {} | |||
| target_edge_ids = set() | |||
| for edge_config in edge_configs: | |||
| source_node_id = edge_config.get('source') | |||
| if not source_node_id: | |||
| continue | |||
| if source_node_id not in edge_mapping: | |||
| edge_mapping[source_node_id] = [] | |||
| target_node_id = edge_config.get('target') | |||
| if not target_node_id: | |||
| continue | |||
| if target_node_id not in reverse_edge_mapping: | |||
| reverse_edge_mapping[target_node_id] = [] | |||
| # is target node id in source node id edge mapping | |||
| if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]): | |||
| continue | |||
| target_edge_ids.add(target_node_id) | |||
| # parse run condition | |||
| run_condition = None | |||
| if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': | |||
| run_condition = RunCondition( | |||
| type='branch_identify', | |||
| branch_identify=edge_config.get('sourceHandle') | |||
| ) | |||
| graph_edge = GraphEdge( | |||
| source_node_id=source_node_id, | |||
| target_node_id=target_node_id, | |||
| run_condition=run_condition | |||
| ) | |||
| edge_mapping[source_node_id].append(graph_edge) | |||
| reverse_edge_mapping[target_node_id].append(graph_edge) | |||
| # node configs | |||
| node_configs = graph_config.get('nodes') | |||
| if not node_configs: | |||
| raise ValueError("Graph must have at least one node") | |||
| node_configs = cast(list, node_configs) | |||
| # fetch nodes that have no predecessor node | |||
| root_node_configs = [] | |||
| all_node_id_config_mapping: dict[str, dict] = {} | |||
| for node_config in node_configs: | |||
| node_id = node_config.get('id') | |||
| if not node_id: | |||
| continue | |||
| if node_id not in target_edge_ids: | |||
| root_node_configs.append(node_config) | |||
| all_node_id_config_mapping[node_id] = node_config | |||
| root_node_ids = [node_config.get('id') for node_config in root_node_configs] | |||
| # fetch root node | |||
| if not root_node_id: | |||
| # if no root node id, use the START type node as root node | |||
| root_node_id = next((node_config.get("id") for node_config in root_node_configs | |||
| if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) | |||
| if not root_node_id or root_node_id not in root_node_ids: | |||
| raise ValueError(f"Root node id {root_node_id} not found in the graph") | |||
| # Check whether it is connected to the previous node | |||
| cls._check_connected_to_previous_node( | |||
| route=[root_node_id], | |||
| edge_mapping=edge_mapping | |||
| ) | |||
| # fetch all node ids from root node | |||
| node_ids = [root_node_id] | |||
| cls._recursively_add_node_ids( | |||
| node_ids=node_ids, | |||
| edge_mapping=edge_mapping, | |||
| node_id=root_node_id | |||
| ) | |||
| node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} | |||
| # init parallel mapping | |||
| parallel_mapping: dict[str, GraphParallel] = {} | |||
| node_parallel_mapping: dict[str, str] = {} | |||
| cls._recursively_add_parallels( | |||
| edge_mapping=edge_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| start_node_id=root_node_id, | |||
| parallel_mapping=parallel_mapping, | |||
| node_parallel_mapping=node_parallel_mapping | |||
| ) | |||
| # Check if it exceeds N layers of parallel | |||
| for parallel in parallel_mapping.values(): | |||
| if parallel.parent_parallel_id: | |||
| cls._check_exceed_parallel_limit( | |||
| parallel_mapping=parallel_mapping, | |||
| level_limit=3, | |||
| parent_parallel_id=parallel.parent_parallel_id | |||
| ) | |||
| # init answer stream generate routes | |||
| answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping | |||
| ) | |||
| # init end stream param | |||
| end_stream_param = EndStreamGeneratorRouter.init( | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| node_parallel_mapping=node_parallel_mapping | |||
| ) | |||
| # init graph | |||
| graph = cls( | |||
| root_node_id=root_node_id, | |||
| node_ids=node_ids, | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| edge_mapping=edge_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| parallel_mapping=parallel_mapping, | |||
| node_parallel_mapping=node_parallel_mapping, | |||
| answer_stream_generate_routes=answer_stream_generate_routes, | |||
| end_stream_param=end_stream_param | |||
| ) | |||
| return graph | |||
| def add_extra_edge(self, source_node_id: str, | |||
| target_node_id: str, | |||
| run_condition: Optional[RunCondition] = None) -> None: | |||
| """ | |||
| Add extra edge to the graph | |||
| :param source_node_id: source node id | |||
| :param target_node_id: target node id | |||
| :param run_condition: run condition | |||
| """ | |||
| if source_node_id not in self.node_ids or target_node_id not in self.node_ids: | |||
| return | |||
| if source_node_id not in self.edge_mapping: | |||
| self.edge_mapping[source_node_id] = [] | |||
| if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: | |||
| return | |||
| graph_edge = GraphEdge( | |||
| source_node_id=source_node_id, | |||
| target_node_id=target_node_id, | |||
| run_condition=run_condition | |||
| ) | |||
| self.edge_mapping[source_node_id].append(graph_edge) | |||
| def get_leaf_node_ids(self) -> list[str]: | |||
| """ | |||
| Get leaf node ids of the graph | |||
| :return: leaf node ids | |||
| """ | |||
| leaf_node_ids = [] | |||
| for node_id in self.node_ids: | |||
| if node_id not in self.edge_mapping: | |||
| leaf_node_ids.append(node_id) | |||
| elif (len(self.edge_mapping[node_id]) == 1 | |||
| and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): | |||
| leaf_node_ids.append(node_id) | |||
| return leaf_node_ids | |||
| @classmethod | |||
| def _recursively_add_node_ids(cls, | |||
| node_ids: list[str], | |||
| edge_mapping: dict[str, list[GraphEdge]], | |||
| node_id: str) -> None: | |||
| """ | |||
| Recursively add node ids | |||
| :param node_ids: node ids | |||
| :param edge_mapping: edge mapping | |||
| :param node_id: node id | |||
| """ | |||
| for graph_edge in edge_mapping.get(node_id, []): | |||
| if graph_edge.target_node_id in node_ids: | |||
| continue | |||
| node_ids.append(graph_edge.target_node_id) | |||
| cls._recursively_add_node_ids( | |||
| node_ids=node_ids, | |||
| edge_mapping=edge_mapping, | |||
| node_id=graph_edge.target_node_id | |||
| ) | |||
| @classmethod | |||
| def _check_connected_to_previous_node( | |||
| cls, | |||
| route: list[str], | |||
| edge_mapping: dict[str, list[GraphEdge]] | |||
| ) -> None: | |||
| """ | |||
| Check whether it is connected to the previous node | |||
| """ | |||
| last_node_id = route[-1] | |||
| for graph_edge in edge_mapping.get(last_node_id, []): | |||
| if not graph_edge.target_node_id: | |||
| continue | |||
| if graph_edge.target_node_id in route: | |||
| raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") | |||
| new_route = route[:] | |||
| new_route.append(graph_edge.target_node_id) | |||
| cls._check_connected_to_previous_node( | |||
| route=new_route, | |||
| edge_mapping=edge_mapping, | |||
| ) | |||
| @classmethod | |||
| def _recursively_add_parallels( | |||
| cls, | |||
| edge_mapping: dict[str, list[GraphEdge]], | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]], | |||
| start_node_id: str, | |||
| parallel_mapping: dict[str, GraphParallel], | |||
| node_parallel_mapping: dict[str, str], | |||
| parent_parallel: Optional[GraphParallel] = None | |||
| ) -> None: | |||
| """ | |||
| Recursively add parallel ids | |||
| :param edge_mapping: edge mapping | |||
| :param start_node_id: start from node id | |||
| :param parallel_mapping: parallel mapping | |||
| :param node_parallel_mapping: node parallel mapping | |||
| :param parent_parallel: parent parallel | |||
| """ | |||
| target_node_edges = edge_mapping.get(start_node_id, []) | |||
| parallel = None | |||
| if len(target_node_edges) > 1: | |||
| # fetch all node ids in current parallels | |||
| parallel_branch_node_ids = [] | |||
| condition_edge_mappings = {} | |||
| for graph_edge in target_node_edges: | |||
| if graph_edge.run_condition is None: | |||
| parallel_branch_node_ids.append(graph_edge.target_node_id) | |||
| else: | |||
| condition_hash = graph_edge.run_condition.hash | |||
| if not condition_hash in condition_edge_mappings: | |||
| condition_edge_mappings[condition_hash] = [] | |||
| condition_edge_mappings[condition_hash].append(graph_edge) | |||
| for _, graph_edges in condition_edge_mappings.items(): | |||
| if len(graph_edges) > 1: | |||
| for graph_edge in graph_edges: | |||
| parallel_branch_node_ids.append(graph_edge.target_node_id) | |||
| # any target node id in node_parallel_mapping | |||
| if parallel_branch_node_ids: | |||
| parent_parallel_id = parent_parallel.id if parent_parallel else None | |||
| parallel = GraphParallel( | |||
| start_from_node_id=start_node_id, | |||
| parent_parallel_id=parent_parallel.id if parent_parallel else None, | |||
| parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None | |||
| ) | |||
| parallel_mapping[parallel.id] = parallel | |||
| in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( | |||
| edge_mapping=edge_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| parallel_branch_node_ids=parallel_branch_node_ids | |||
| ) | |||
| # collect all branches node ids | |||
| parallel_node_ids = [] | |||
| for _, node_ids in in_branch_node_ids.items(): | |||
| for node_id in node_ids: | |||
| in_parent_parallel = True | |||
| if parent_parallel_id: | |||
| in_parent_parallel = False | |||
| for parallel_node_id, parallel_id in node_parallel_mapping.items(): | |||
| if parallel_id == parent_parallel_id and parallel_node_id == node_id: | |||
| in_parent_parallel = True | |||
| break | |||
| if in_parent_parallel: | |||
| parallel_node_ids.append(node_id) | |||
| node_parallel_mapping[node_id] = parallel.id | |||
| outside_parallel_target_node_ids = set() | |||
| for node_id in parallel_node_ids: | |||
| if node_id == parallel.start_from_node_id: | |||
| continue | |||
| node_edges = edge_mapping.get(node_id) | |||
| if not node_edges: | |||
| continue | |||
| if len(node_edges) > 1: | |||
| continue | |||
| target_node_id = node_edges[0].target_node_id | |||
| if target_node_id in parallel_node_ids: | |||
| continue | |||
| if parent_parallel_id: | |||
| parent_parallel = parallel_mapping.get(parent_parallel_id) | |||
| if not parent_parallel: | |||
| continue | |||
| if ( | |||
| (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) | |||
| or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) | |||
| or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) | |||
| ): | |||
| outside_parallel_target_node_ids.add(target_node_id) | |||
| if len(outside_parallel_target_node_ids) == 1: | |||
| if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: | |||
| parallel.end_to_node_id = None | |||
| else: | |||
| parallel.end_to_node_id = outside_parallel_target_node_ids.pop() | |||
| for graph_edge in target_node_edges: | |||
| current_parallel = None | |||
| if parallel: | |||
| current_parallel = parallel | |||
| elif parent_parallel: | |||
| if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): | |||
| current_parallel = parent_parallel | |||
| else: | |||
| # fetch parent parallel's parent parallel | |||
| parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id | |||
| if parent_parallel_parent_parallel_id: | |||
| parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) | |||
| if ( | |||
| parent_parallel_parent_parallel | |||
| and ( | |||
| not parent_parallel_parent_parallel.end_to_node_id | |||
| or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) | |||
| ) | |||
| ): | |||
| current_parallel = parent_parallel_parent_parallel | |||
| cls._recursively_add_parallels( | |||
| edge_mapping=edge_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| start_node_id=graph_edge.target_node_id, | |||
| parallel_mapping=parallel_mapping, | |||
| node_parallel_mapping=node_parallel_mapping, | |||
| parent_parallel=current_parallel | |||
| ) | |||
| @classmethod | |||
| def _check_exceed_parallel_limit( | |||
| cls, | |||
| parallel_mapping: dict[str, GraphParallel], | |||
| level_limit: int, | |||
| parent_parallel_id: str, | |||
| current_level: int = 1 | |||
| ) -> None: | |||
| """ | |||
| Check if it exceeds N layers of parallel | |||
| """ | |||
| parent_parallel = parallel_mapping.get(parent_parallel_id) | |||
| if not parent_parallel: | |||
| return | |||
| current_level += 1 | |||
| if current_level > level_limit: | |||
| raise ValueError(f"Exceeds {level_limit} layers of parallel") | |||
| if parent_parallel.parent_parallel_id: | |||
| cls._check_exceed_parallel_limit( | |||
| parallel_mapping=parallel_mapping, | |||
| level_limit=level_limit, | |||
| parent_parallel_id=parent_parallel.parent_parallel_id, | |||
| current_level=current_level | |||
| ) | |||
| @classmethod | |||
| def _recursively_add_parallel_node_ids(cls, | |||
| branch_node_ids: list[str], | |||
| edge_mapping: dict[str, list[GraphEdge]], | |||
| merge_node_id: str, | |||
| start_node_id: str) -> None: | |||
| """ | |||
| Recursively add node ids | |||
| :param branch_node_ids: in branch node ids | |||
| :param edge_mapping: edge mapping | |||
| :param merge_node_id: merge node id | |||
| :param start_node_id: start node id | |||
| """ | |||
| for graph_edge in edge_mapping.get(start_node_id, []): | |||
| if (graph_edge.target_node_id != merge_node_id | |||
| and graph_edge.target_node_id not in branch_node_ids): | |||
| branch_node_ids.append(graph_edge.target_node_id) | |||
| cls._recursively_add_parallel_node_ids( | |||
| branch_node_ids=branch_node_ids, | |||
| edge_mapping=edge_mapping, | |||
| merge_node_id=merge_node_id, | |||
| start_node_id=graph_edge.target_node_id | |||
| ) | |||
| @classmethod | |||
| def _fetch_all_node_ids_in_parallels(cls, | |||
| edge_mapping: dict[str, list[GraphEdge]], | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]], | |||
| parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: | |||
| """ | |||
| Fetch all node ids in parallels | |||
| """ | |||
| routes_node_ids: dict[str, list[str]] = {} | |||
| for parallel_branch_node_id in parallel_branch_node_ids: | |||
| routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] | |||
| # fetch routes node ids | |||
| cls._recursively_fetch_routes( | |||
| edge_mapping=edge_mapping, | |||
| start_node_id=parallel_branch_node_id, | |||
| routes_node_ids=routes_node_ids[parallel_branch_node_id] | |||
| ) | |||
| # fetch leaf node ids from routes node ids | |||
| leaf_node_ids: dict[str, list[str]] = {} | |||
| merge_branch_node_ids: dict[str, list[str]] = {} | |||
| for branch_node_id, node_ids in routes_node_ids.items(): | |||
| for node_id in node_ids: | |||
| if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: | |||
| if branch_node_id not in leaf_node_ids: | |||
| leaf_node_ids[branch_node_id] = [] | |||
| leaf_node_ids[branch_node_id].append(node_id) | |||
| for branch_node_id2, inner_route2 in routes_node_ids.items(): | |||
| if ( | |||
| branch_node_id != branch_node_id2 | |||
| and node_id in inner_route2 | |||
| and len(reverse_edge_mapping.get(node_id, [])) > 1 | |||
| and cls._is_node_in_routes( | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| start_node_id=node_id, | |||
| routes_node_ids=routes_node_ids | |||
| ) | |||
| ): | |||
| if node_id not in merge_branch_node_ids: | |||
| merge_branch_node_ids[node_id] = [] | |||
| if branch_node_id2 not in merge_branch_node_ids[node_id]: | |||
| merge_branch_node_ids[node_id].append(branch_node_id2) | |||
| # sorted merge_branch_node_ids by branch_node_ids length desc | |||
| merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) | |||
| duplicate_end_node_ids = {} | |||
| for node_id, branch_node_ids in merge_branch_node_ids.items(): | |||
| for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): | |||
| if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): | |||
| if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: | |||
| duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids | |||
| for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): | |||
| # check which node is after | |||
| if cls._is_node2_after_node1( | |||
| node1_id=node_id, | |||
| node2_id=node_id2, | |||
| edge_mapping=edge_mapping | |||
| ): | |||
| if node_id in merge_branch_node_ids: | |||
| del merge_branch_node_ids[node_id2] | |||
| elif cls._is_node2_after_node1( | |||
| node1_id=node_id2, | |||
| node2_id=node_id, | |||
| edge_mapping=edge_mapping | |||
| ): | |||
| if node_id2 in merge_branch_node_ids: | |||
| del merge_branch_node_ids[node_id] | |||
| branches_merge_node_ids: dict[str, str] = {} | |||
| for node_id, branch_node_ids in merge_branch_node_ids.items(): | |||
| if len(branch_node_ids) <= 1: | |||
| continue | |||
| for branch_node_id in branch_node_ids: | |||
| if branch_node_id in branches_merge_node_ids: | |||
| continue | |||
| branches_merge_node_ids[branch_node_id] = node_id | |||
| in_branch_node_ids: dict[str, list[str]] = {} | |||
| for branch_node_id, node_ids in routes_node_ids.items(): | |||
| in_branch_node_ids[branch_node_id] = [] | |||
| if branch_node_id not in branches_merge_node_ids: | |||
| # all node ids in current branch is in this thread | |||
| in_branch_node_ids[branch_node_id].append(branch_node_id) | |||
| in_branch_node_ids[branch_node_id].extend(node_ids) | |||
| else: | |||
| merge_node_id = branches_merge_node_ids[branch_node_id] | |||
| if merge_node_id != branch_node_id: | |||
| in_branch_node_ids[branch_node_id].append(branch_node_id) | |||
| # fetch all node ids from branch_node_id and merge_node_id | |||
| cls._recursively_add_parallel_node_ids( | |||
| branch_node_ids=in_branch_node_ids[branch_node_id], | |||
| edge_mapping=edge_mapping, | |||
| merge_node_id=merge_node_id, | |||
| start_node_id=branch_node_id | |||
| ) | |||
| return in_branch_node_ids | |||
| @classmethod | |||
| def _recursively_fetch_routes(cls, | |||
| edge_mapping: dict[str, list[GraphEdge]], | |||
| start_node_id: str, | |||
| routes_node_ids: list[str]) -> None: | |||
| """ | |||
| Recursively fetch route | |||
| """ | |||
| if start_node_id not in edge_mapping: | |||
| return | |||
| for graph_edge in edge_mapping[start_node_id]: | |||
| # find next node ids | |||
| if graph_edge.target_node_id not in routes_node_ids: | |||
| routes_node_ids.append(graph_edge.target_node_id) | |||
| cls._recursively_fetch_routes( | |||
| edge_mapping=edge_mapping, | |||
| start_node_id=graph_edge.target_node_id, | |||
| routes_node_ids=routes_node_ids | |||
| ) | |||
| @classmethod | |||
| def _is_node_in_routes(cls, | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]], | |||
| start_node_id: str, | |||
| routes_node_ids: dict[str, list[str]]) -> bool: | |||
| """ | |||
| Recursively check if the node is in the routes | |||
| """ | |||
| if start_node_id not in reverse_edge_mapping: | |||
| return False | |||
| all_routes_node_ids = set() | |||
| parallel_start_node_ids: dict[str, list[str]] = {} | |||
| for branch_node_id, node_ids in routes_node_ids.items(): | |||
| for node_id in node_ids: | |||
| all_routes_node_ids.add(node_id) | |||
| if branch_node_id in reverse_edge_mapping: | |||
| for graph_edge in reverse_edge_mapping[branch_node_id]: | |||
| if graph_edge.source_node_id not in parallel_start_node_ids: | |||
| parallel_start_node_ids[graph_edge.source_node_id] = [] | |||
| parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) | |||
| parallel_start_node_id = None | |||
| for p_start_node_id, branch_node_ids in parallel_start_node_ids.items(): | |||
| if set(branch_node_ids) == set(routes_node_ids.keys()): | |||
| parallel_start_node_id = p_start_node_id | |||
| return True | |||
| if not parallel_start_node_id: | |||
| raise Exception("Parallel start node id not found") | |||
| for graph_edge in reverse_edge_mapping[start_node_id]: | |||
| if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: | |||
| return False | |||
| return True | |||
| @classmethod | |||
| def _is_node2_after_node1( | |||
| cls, | |||
| node1_id: str, | |||
| node2_id: str, | |||
| edge_mapping: dict[str, list[GraphEdge]] | |||
| ) -> bool: | |||
| """ | |||
| is node2 after node1 | |||
| """ | |||
| if node1_id not in edge_mapping: | |||
| return False | |||
| for graph_edge in edge_mapping[node1_id]: | |||
| if graph_edge.target_node_id == node2_id: | |||
| return True | |||
| if cls._is_node2_after_node1( | |||
| node1_id=graph_edge.target_node_id, | |||
| node2_id=node2_id, | |||
| edge_mapping=edge_mapping | |||
| ): | |||
| return True | |||
| return False | |||
| @@ -0,0 +1,21 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from models.workflow import WorkflowType | |||
| class GraphInitParams(BaseModel): | |||
| # init params | |||
| tenant_id: str = Field(..., description="tenant / workspace id") | |||
| app_id: str = Field(..., description="app id") | |||
| workflow_type: WorkflowType = Field(..., description="workflow type") | |||
| workflow_id: str = Field(..., description="workflow id") | |||
| graph_config: Mapping[str, Any] = Field(..., description="graph config") | |||
| user_id: str = Field(..., description="user id") | |||
| user_from: UserFrom = Field(..., description="user from, account or end-user") | |||
| invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") | |||
| call_depth: int = Field(..., description="call depth") | |||
| @@ -0,0 +1,27 @@ | |||
| from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState | |||
| class GraphRuntimeState(BaseModel): | |||
| variable_pool: VariablePool = Field(..., description="variable pool") | |||
| """variable pool""" | |||
| start_at: float = Field(..., description="start time") | |||
| """start time""" | |||
| total_tokens: int = 0 | |||
| """total tokens""" | |||
| llm_usage: LLMUsage = LLMUsage.empty_usage() | |||
| """llm usage info""" | |||
| outputs: dict[str, Any] = {} | |||
| """outputs""" | |||
| node_run_steps: int = 0 | |||
| """node run steps""" | |||
| node_run_state: RuntimeRouteState = RuntimeRouteState() | |||
| """node run state""" | |||
| @@ -0,0 +1,13 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.workflow.graph_engine.entities.graph import GraphParallel | |||
| class NextGraphNode(BaseModel): | |||
| node_id: str | |||
| """next node id""" | |||
| parallel: Optional[GraphParallel] = None | |||
| """parallel""" | |||
| @@ -0,0 +1,21 @@ | |||
| import hashlib | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel | |||
| from core.workflow.utils.condition.entities import Condition | |||
| class RunCondition(BaseModel): | |||
| type: Literal["branch_identify", "condition"] | |||
| """condition type""" | |||
| branch_identify: Optional[str] = None | |||
| """branch identify like: sourceHandle, required when type is branch_identify""" | |||
| conditions: Optional[list[Condition]] = None | |||
| """conditions to run the node, required when type is condition""" | |||
| @property | |||
| def hash(self) -> str: | |||
| return hashlib.sha256(self.model_dump_json().encode()).hexdigest() | |||
| @@ -0,0 +1,111 @@ | |||
| import uuid | |||
| from datetime import datetime, timezone | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class RouteNodeState(BaseModel): | |||
| class Status(Enum): | |||
| RUNNING = "running" | |||
| SUCCESS = "success" | |||
| FAILED = "failed" | |||
| PAUSED = "paused" | |||
| id: str = Field(default_factory=lambda: str(uuid.uuid4())) | |||
| """node state id""" | |||
| node_id: str | |||
| """node id""" | |||
| node_run_result: Optional[NodeRunResult] = None | |||
| """node run result""" | |||
| status: Status = Status.RUNNING | |||
| """node status""" | |||
| start_at: datetime | |||
| """start time""" | |||
| paused_at: Optional[datetime] = None | |||
| """paused time""" | |||
| finished_at: Optional[datetime] = None | |||
| """finished time""" | |||
| failed_reason: Optional[str] = None | |||
| """failed reason""" | |||
| paused_by: Optional[str] = None | |||
| """paused by""" | |||
| index: int = 1 | |||
| def set_finished(self, run_result: NodeRunResult) -> None: | |||
| """ | |||
| Node finished | |||
| :param run_result: run result | |||
| """ | |||
| if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: | |||
| raise Exception(f"Route state {self.id} already finished") | |||
| if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| self.status = RouteNodeState.Status.SUCCESS | |||
| elif run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| self.status = RouteNodeState.Status.FAILED | |||
| self.failed_reason = run_result.error | |||
| else: | |||
| raise Exception(f"Invalid route status {run_result.status}") | |||
| self.node_run_result = run_result | |||
| self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| class RuntimeRouteState(BaseModel): | |||
| routes: dict[str, list[str]] = Field( | |||
| default_factory=dict, | |||
| description="graph state routes (source_node_state_id: target_node_state_id)" | |||
| ) | |||
| node_state_mapping: dict[str, RouteNodeState] = Field( | |||
| default_factory=dict, | |||
| description="node state mapping (route_node_state_id: route_node_state)" | |||
| ) | |||
| def create_node_state(self, node_id: str) -> RouteNodeState: | |||
| """ | |||
| Create node state | |||
| :param node_id: node id | |||
| """ | |||
| state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) | |||
| self.node_state_mapping[state.id] = state | |||
| return state | |||
| def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: | |||
| """ | |||
| Add route to the graph state | |||
| :param source_node_state_id: source node state id | |||
| :param target_node_state_id: target node state id | |||
| """ | |||
| if source_node_state_id not in self.routes: | |||
| self.routes[source_node_state_id] = [] | |||
| self.routes[source_node_state_id].append(target_node_state_id) | |||
| def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ | |||
| -> list[RouteNodeState]: | |||
| """ | |||
| Get routes with node state by source node id | |||
| :param source_node_state_id: source node state id | |||
| :return: routes with node state | |||
| """ | |||
| return [self.node_state_mapping[target_state_id] | |||
| for target_state_id in self.routes.get(source_node_state_id, [])] | |||
| @@ -0,0 +1,716 @@ | |||
| import logging | |||
| import queue | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from concurrent.futures import ThreadPoolExecutor, wait | |||
| from typing import Any, Optional | |||
| from flask import Flask, current_app | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import ( | |||
| NodeRunMetadataKey, | |||
| NodeType, | |||
| UserFrom, | |||
| ) | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseIterationEvent, | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ParallelBranchRunFailedEvent, | |||
| ParallelBranchRunStartedEvent, | |||
| ParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph, GraphEdge | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| from extensions.ext_database import db | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| class GraphEngineThreadPool(ThreadPoolExecutor): | |||
| def __init__(self, max_workers=None, thread_name_prefix='', | |||
| initializer=None, initargs=(), max_submit_count=100) -> None: | |||
| super().__init__(max_workers, thread_name_prefix, initializer, initargs) | |||
| self.max_submit_count = max_submit_count | |||
| self.submit_count = 0 | |||
| def submit(self, fn, *args, **kwargs): | |||
| self.submit_count += 1 | |||
| self.check_is_full() | |||
| return super().submit(fn, *args, **kwargs) | |||
| def check_is_full(self) -> None: | |||
| print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") | |||
| if self.submit_count > self.max_submit_count: | |||
| raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") | |||
| class GraphEngine: | |||
| workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} | |||
| def __init__( | |||
| self, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| workflow_type: WorkflowType, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| call_depth: int, | |||
| graph: Graph, | |||
| graph_config: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| max_execution_steps: int, | |||
| max_execution_time: int, | |||
| thread_pool_id: Optional[str] = None | |||
| ) -> None: | |||
| thread_pool_max_submit_count = 100 | |||
| thread_pool_max_workers = 10 | |||
| ## init thread pool | |||
| if thread_pool_id: | |||
| if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: | |||
| raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") | |||
| self.thread_pool_id = thread_pool_id | |||
| self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] | |||
| self.is_main_thread_pool = False | |||
| else: | |||
| self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) | |||
| self.thread_pool_id = str(uuid.uuid4()) | |||
| self.is_main_thread_pool = True | |||
| GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool | |||
| self.graph = graph | |||
| self.init_params = GraphInitParams( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_type=workflow_type, | |||
| workflow_id=workflow_id, | |||
| graph_config=graph_config, | |||
| user_id=user_id, | |||
| user_from=user_from, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth | |||
| ) | |||
| self.graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=time.perf_counter() | |||
| ) | |||
| self.max_execution_steps = max_execution_steps | |||
| self.max_execution_time = max_execution_time | |||
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||
| # trigger graph run start event | |||
| yield GraphRunStartedEvent() | |||
| try: | |||
| stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] | |||
| if self.init_params.workflow_type == WorkflowType.CHAT: | |||
| stream_processor_cls = AnswerStreamProcessor | |||
| else: | |||
| stream_processor_cls = EndStreamProcessor | |||
| stream_processor = stream_processor_cls( | |||
| graph=self.graph, | |||
| variable_pool=self.graph_runtime_state.variable_pool | |||
| ) | |||
| # run graph | |||
| generator = stream_processor.process( | |||
| self._run(start_node_id=self.graph.root_node_id) | |||
| ) | |||
| for item in generator: | |||
| try: | |||
| yield item | |||
| if isinstance(item, NodeRunFailedEvent): | |||
| yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') | |||
| return | |||
| elif isinstance(item, NodeRunSucceededEvent): | |||
| if item.node_type == NodeType.END: | |||
| self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs | |||
| if item.route_node_state.node_run_result | |||
| and item.route_node_state.node_run_result.outputs | |||
| else {}) | |||
| elif item.node_type == NodeType.ANSWER: | |||
| if "answer" not in self.graph_runtime_state.outputs: | |||
| self.graph_runtime_state.outputs["answer"] = "" | |||
| self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") | |||
| if item.route_node_state.node_run_result | |||
| and item.route_node_state.node_run_result.outputs | |||
| else "") | |||
| self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() | |||
| except Exception as e: | |||
| logger.exception(f"Graph run failed: {str(e)}") | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| return | |||
| # trigger graph run success event | |||
| yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) | |||
| except GraphRunFailedError as e: | |||
| yield GraphRunFailedEvent(error=e.error) | |||
| return | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when graph running") | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| raise e | |||
| finally: | |||
| if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: | |||
| del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] | |||
| def _run( | |||
| self, | |||
| start_node_id: str, | |||
| in_parallel_id: Optional[str] = None, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| parallel_start_node_id = None | |||
| if in_parallel_id: | |||
| parallel_start_node_id = start_node_id | |||
| next_node_id = start_node_id | |||
| previous_route_node_state: Optional[RouteNodeState] = None | |||
| while True: | |||
| # max steps reached | |||
| if self.graph_runtime_state.node_run_steps > self.max_execution_steps: | |||
| raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) | |||
| # or max execution time reached | |||
| if self._is_timed_out( | |||
| start_at=self.graph_runtime_state.start_at, | |||
| max_execution_time=self.max_execution_time | |||
| ): | |||
| raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) | |||
| # init route node state | |||
| route_node_state = self.graph_runtime_state.node_run_state.create_node_state( | |||
| node_id=next_node_id | |||
| ) | |||
| # get node config | |||
| node_id = route_node_state.node_id | |||
| node_config = self.graph.node_id_config_mapping.get(node_id) | |||
| if not node_config: | |||
| raise GraphRunFailedError(f'Node {node_id} config not found.') | |||
| # convert to specific node | |||
| node_type = NodeType.value_of(node_config.get('data', {}).get('type')) | |||
| node_cls = node_classes.get(node_type) | |||
| if not node_cls: | |||
| raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') | |||
| previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None | |||
| # init workflow run state | |||
| node_instance = node_cls( # type: ignore | |||
| id=route_node_state.id, | |||
| config=node_config, | |||
| graph_init_params=self.init_params, | |||
| graph=self.graph, | |||
| graph_runtime_state=self.graph_runtime_state, | |||
| previous_node_id=previous_node_id, | |||
| thread_pool_id=self.thread_pool_id | |||
| ) | |||
| try: | |||
| # run node | |||
| generator = self._run_node( | |||
| node_instance=node_instance, | |||
| route_node_state=route_node_state, | |||
| parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| for item in generator: | |||
| if isinstance(item, NodeRunStartedEvent): | |||
| self.graph_runtime_state.node_run_steps += 1 | |||
| item.route_node_state.index = self.graph_runtime_state.node_run_steps | |||
| yield item | |||
| self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state | |||
| # append route | |||
| if previous_route_node_state: | |||
| self.graph_runtime_state.node_run_state.add_route( | |||
| source_node_state_id=previous_route_node_state.id, | |||
| target_node_state_id=route_node_state.id | |||
| ) | |||
| except Exception as e: | |||
| route_node_state.status = RouteNodeState.Status.FAILED | |||
| route_node_state.failed_reason = str(e) | |||
| yield NodeRunFailedEvent( | |||
| error=str(e), | |||
| id=node_instance.id, | |||
| node_id=next_node_id, | |||
| node_type=node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| raise e | |||
| # It may not be necessary, but it is necessary. :) | |||
| if (self.graph.node_id_config_mapping[next_node_id] | |||
| .get("data", {}).get("type", "").lower() == NodeType.END.value): | |||
| break | |||
| previous_route_node_state = route_node_state | |||
| # get next node ids | |||
| edge_mappings = self.graph.edge_mapping.get(next_node_id) | |||
| if not edge_mappings: | |||
| break | |||
| if len(edge_mappings) == 1: | |||
| edge = edge_mappings[0] | |||
| if edge.run_condition: | |||
| result = ConditionManager.get_condition_handler( | |||
| init_params=self.init_params, | |||
| graph=self.graph, | |||
| run_condition=edge.run_condition, | |||
| ).check( | |||
| graph_runtime_state=self.graph_runtime_state, | |||
| previous_route_node_state=previous_route_node_state | |||
| ) | |||
| if not result: | |||
| break | |||
| next_node_id = edge.target_node_id | |||
| else: | |||
| final_node_id = None | |||
| if any(edge.run_condition for edge in edge_mappings): | |||
| # if nodes has run conditions, get node id which branch to take based on the run condition results | |||
| condition_edge_mappings = {} | |||
| for edge in edge_mappings: | |||
| if edge.run_condition: | |||
| run_condition_hash = edge.run_condition.hash | |||
| if run_condition_hash not in condition_edge_mappings: | |||
| condition_edge_mappings[run_condition_hash] = [] | |||
| condition_edge_mappings[run_condition_hash].append(edge) | |||
| for _, sub_edge_mappings in condition_edge_mappings.items(): | |||
| if len(sub_edge_mappings) == 0: | |||
| continue | |||
| edge = sub_edge_mappings[0] | |||
| result = ConditionManager.get_condition_handler( | |||
| init_params=self.init_params, | |||
| graph=self.graph, | |||
| run_condition=edge.run_condition, | |||
| ).check( | |||
| graph_runtime_state=self.graph_runtime_state, | |||
| previous_route_node_state=previous_route_node_state, | |||
| ) | |||
| if not result: | |||
| continue | |||
| if len(sub_edge_mappings) == 1: | |||
| final_node_id = edge.target_node_id | |||
| else: | |||
| parallel_generator = self._run_parallel_branches( | |||
| edge_mappings=sub_edge_mappings, | |||
| in_parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id | |||
| ) | |||
| for item in parallel_generator: | |||
| if isinstance(item, str): | |||
| final_node_id = item | |||
| else: | |||
| yield item | |||
| break | |||
| if not final_node_id: | |||
| break | |||
| next_node_id = final_node_id | |||
| else: | |||
| parallel_generator = self._run_parallel_branches( | |||
| edge_mappings=edge_mappings, | |||
| in_parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id | |||
| ) | |||
| for item in parallel_generator: | |||
| if isinstance(item, str): | |||
| final_node_id = item | |||
| else: | |||
| yield item | |||
| if not final_node_id: | |||
| break | |||
| next_node_id = final_node_id | |||
| if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: | |||
| break | |||
| def _run_parallel_branches( | |||
| self, | |||
| edge_mappings: list[GraphEdge], | |||
| in_parallel_id: Optional[str] = None, | |||
| parallel_start_node_id: Optional[str] = None, | |||
| ) -> Generator[GraphEngineEvent | str, None, None]: | |||
| # if nodes has no run conditions, parallel run all nodes | |||
| parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) | |||
| if not parallel_id: | |||
| node_id = edge_mappings[0].target_node_id | |||
| node_config = self.graph.node_id_config_mapping.get(node_id) | |||
| if not node_config: | |||
| raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') | |||
| node_title = node_config.get('data', {}).get('title') | |||
| raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') | |||
| parallel = self.graph.parallel_mapping.get(parallel_id) | |||
| if not parallel: | |||
| raise GraphRunFailedError(f'Parallel {parallel_id} not found.') | |||
| # run parallel nodes, run in new thread and use queue to get results | |||
| q: queue.Queue = queue.Queue() | |||
| # Create a list to store the threads | |||
| futures = [] | |||
| # new thread | |||
| for edge in edge_mappings: | |||
| if ( | |||
| edge.target_node_id not in self.graph.node_parallel_mapping | |||
| or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id | |||
| ): | |||
| continue | |||
| futures.append( | |||
| self.thread_pool.submit(self._run_parallel_node, **{ | |||
| 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] | |||
| 'q': q, | |||
| 'parallel_id': parallel_id, | |||
| 'parallel_start_node_id': edge.target_node_id, | |||
| 'parent_parallel_id': in_parallel_id, | |||
| 'parent_parallel_start_node_id': parallel_start_node_id, | |||
| }) | |||
| ) | |||
| succeeded_count = 0 | |||
| while True: | |||
| try: | |||
| event = q.get(timeout=1) | |||
| if event is None: | |||
| break | |||
| yield event | |||
| if event.parallel_id == parallel_id: | |||
| if isinstance(event, ParallelBranchRunSucceededEvent): | |||
| succeeded_count += 1 | |||
| if succeeded_count == len(futures): | |||
| q.put(None) | |||
| continue | |||
| elif isinstance(event, ParallelBranchRunFailedEvent): | |||
| raise GraphRunFailedError(event.error) | |||
| except queue.Empty: | |||
| continue | |||
| # wait all threads | |||
| wait(futures) | |||
| # get final node id | |||
| final_node_id = parallel.end_to_node_id | |||
| if final_node_id: | |||
| yield final_node_id | |||
| def _run_parallel_node( | |||
| self, | |||
| flask_app: Flask, | |||
| q: queue.Queue, | |||
| parallel_id: str, | |||
| parallel_start_node_id: str, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| Run parallel nodes | |||
| """ | |||
| with flask_app.app_context(): | |||
| try: | |||
| q.put(ParallelBranchRunStartedEvent( | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| )) | |||
| # run node | |||
| generator = self._run( | |||
| start_node_id=parallel_start_node_id, | |||
| in_parallel_id=parallel_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| for item in generator: | |||
| q.put(item) | |||
| # trigger graph run success event | |||
| q.put(ParallelBranchRunSucceededEvent( | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| )) | |||
| except GraphRunFailedError as e: | |||
| q.put(ParallelBranchRunFailedEvent( | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| error=e.error | |||
| )) | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating in parallel") | |||
| q.put(ParallelBranchRunFailedEvent( | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| error=str(e) | |||
| )) | |||
| finally: | |||
| db.session.remove() | |||
| def _run_node( | |||
| self, | |||
| node_instance: BaseNode, | |||
| route_node_state: RouteNodeState, | |||
| parallel_id: Optional[str] = None, | |||
| parallel_start_node_id: Optional[str] = None, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None, | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| Run node | |||
| """ | |||
| # trigger node run start event | |||
| yield NodeRunStartedEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| predecessor_node_id=node_instance.previous_node_id, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| db.session.close() | |||
| try: | |||
| # run node | |||
| generator = node_instance.run() | |||
| for item in generator: | |||
| if isinstance(item, GraphEngineEvent): | |||
| if isinstance(item, BaseIterationEvent): | |||
| # add parallel info to iteration event | |||
| item.parallel_id = parallel_id | |||
| item.parallel_start_node_id = parallel_start_node_id | |||
| item.parent_parallel_id = parent_parallel_id | |||
| item.parent_parallel_start_node_id = parent_parallel_start_node_id | |||
| yield item | |||
| else: | |||
| if isinstance(item, RunCompletedEvent): | |||
| run_result = item.run_result | |||
| route_node_state.set_finished(run_result=run_result) | |||
| if run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| yield NodeRunFailedEvent( | |||
| error=route_node_state.failed_reason or 'Unknown error.', | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| # plus state total_tokens | |||
| self.graph_runtime_state.total_tokens += int( | |||
| run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] | |||
| ) | |||
| if run_result.llm_usage: | |||
| # use the latest usage | |||
| self.graph_runtime_state.llm_usage += run_result.llm_usage | |||
| # append node output variables to variable pool | |||
| if run_result.outputs: | |||
| for variable_key, variable_value in run_result.outputs.items(): | |||
| # append variables to variable pool recursively | |||
| self._append_variables_recursively( | |||
| node_id=node_instance.node_id, | |||
| variable_key_list=[variable_key], | |||
| variable_value=variable_value | |||
| ) | |||
| # add parallel info to run result metadata | |||
| if parallel_id and parallel_start_node_id: | |||
| if not run_result.metadata: | |||
| run_result.metadata = {} | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id | |||
| if parent_parallel_id and parent_parallel_start_node_id: | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id | |||
| yield NodeRunSucceededEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| break | |||
| elif isinstance(item, RunStreamChunkEvent): | |||
| yield NodeRunStreamChunkEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| chunk_content=item.chunk_content, | |||
| from_variable_selector=item.from_variable_selector, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| elif isinstance(item, RunRetrieverResourceEvent): | |||
| yield NodeRunRetrieverResourceEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| retriever_resources=item.retriever_resources, | |||
| context=item.context, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| # trigger node run failed event | |||
| route_node_state.status = RouteNodeState.Status.FAILED | |||
| route_node_state.failed_reason = "Workflow stopped." | |||
| yield NodeRunFailedEvent( | |||
| error="Workflow stopped.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id | |||
| ) | |||
| return | |||
| except Exception as e: | |||
| logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") | |||
| raise e | |||
| finally: | |||
| db.session.close() | |||
| def _append_variables_recursively(self, | |||
| node_id: str, | |||
| variable_key_list: list[str], | |||
| variable_value: VariableValue): | |||
| """ | |||
| Append variables recursively | |||
| :param node_id: node id | |||
| :param variable_key_list: variable key list | |||
| :param variable_value: variable value | |||
| :return: | |||
| """ | |||
| self.graph_runtime_state.variable_pool.add( | |||
| [node_id] + variable_key_list, | |||
| variable_value | |||
| ) | |||
| # if variable_value is a dict, then recursively append variables | |||
| if isinstance(variable_value, dict): | |||
| for key, value in variable_value.items(): | |||
| # construct new key list | |||
| new_key_list = variable_key_list + [key] | |||
| self._append_variables_recursively( | |||
| node_id=node_id, | |||
| variable_key_list=new_key_list, | |||
| variable_value=value | |||
| ) | |||
| def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: | |||
| """ | |||
| Check timeout | |||
| :param start_at: start time | |||
| :param max_execution_time: max execution time | |||
| :return: | |||
| """ | |||
| return time.perf_counter() - start_at > max_execution_time | |||
| class GraphRunFailedError(Exception): | |||
| def __init__(self, error: str): | |||
| self.error = error | |||
| @@ -1,9 +1,8 @@ | |||
| from typing import cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter | |||
| from core.workflow.nodes.answer.entities import ( | |||
| AnswerNodeData, | |||
| GenerateRouteChunk, | |||
| @@ -19,24 +18,26 @@ class AnswerNode(BaseNode): | |||
| _node_data_cls = AnswerNodeData | |||
| _node_type: NodeType = NodeType.ANSWER | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_data = self.node_data | |||
| node_data = cast(AnswerNodeData, node_data) | |||
| # generate routes | |||
| generate_routes = self.extract_generate_route_from_node_data(node_data) | |||
| generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) | |||
| answer = '' | |||
| for part in generate_routes: | |||
| if part.type == "var": | |||
| if part.type == GenerateRouteChunk.ChunkType.VAR: | |||
| part = cast(VarGenerateRouteChunk, part) | |||
| value_selector = part.value_selector | |||
| value = variable_pool.get(value_selector) | |||
| value = self.graph_runtime_state.variable_pool.get( | |||
| value_selector | |||
| ) | |||
| if value: | |||
| answer += value.markdown | |||
| else: | |||
| @@ -51,70 +52,16 @@ class AnswerNode(BaseNode): | |||
| ) | |||
| @classmethod | |||
| def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: | |||
| """ | |||
| Extract generate route selectors | |||
| :param config: node config | |||
| :return: | |||
| """ | |||
| node_data = cls._node_data_cls(**config.get("data", {})) | |||
| node_data = cast(AnswerNodeData, node_data) | |||
| return cls.extract_generate_route_from_node_data(node_data) | |||
| @classmethod | |||
| def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: | |||
| """ | |||
| Extract generate route from node data | |||
| :param node_data: node data object | |||
| :return: | |||
| """ | |||
| variable_template_parser = VariableTemplateParser(template=node_data.answer) | |||
| variable_selectors = variable_template_parser.extract_variable_selectors() | |||
| value_selector_mapping = { | |||
| variable_selector.variable: variable_selector.value_selector | |||
| for variable_selector in variable_selectors | |||
| } | |||
| variable_keys = list(value_selector_mapping.keys()) | |||
| # format answer template | |||
| template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) | |||
| template_variable_keys = template_parser.variable_keys | |||
| # Take the intersection of variable_keys and template_variable_keys | |||
| variable_keys = list(set(variable_keys) & set(template_variable_keys)) | |||
| template = node_data.answer | |||
| for var in variable_keys: | |||
| template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') | |||
| generate_routes = [] | |||
| for part in template.split('Ω'): | |||
| if part: | |||
| if cls._is_variable(part, variable_keys): | |||
| var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') | |||
| value_selector = value_selector_mapping[var_key] | |||
| generate_routes.append(VarGenerateRouteChunk( | |||
| value_selector=value_selector | |||
| )) | |||
| else: | |||
| generate_routes.append(TextGenerateRouteChunk( | |||
| text=part | |||
| )) | |||
| return generate_routes | |||
| @classmethod | |||
| def _is_variable(cls, part, variable_keys): | |||
| cleaned_part = part.replace('{{', '').replace('}}', '') | |||
| return part.startswith('{{') and cleaned_part in variable_keys | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: AnswerNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -126,6 +73,6 @@ class AnswerNode(BaseNode): | |||
| variable_mapping = {} | |||
| for variable_selector in variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector | |||
| return variable_mapping | |||
| @@ -0,0 +1,169 @@ | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.answer.entities import ( | |||
| AnswerNodeData, | |||
| AnswerStreamGenerateRoute, | |||
| GenerateRouteChunk, | |||
| TextGenerateRouteChunk, | |||
| VarGenerateRouteChunk, | |||
| ) | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| class AnswerStreamGeneratorRouter: | |||
| @classmethod | |||
| def init(cls, | |||
| node_id_config_mapping: dict[str, dict], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] | |||
| ) -> AnswerStreamGenerateRoute: | |||
| """ | |||
| Get stream generate routes. | |||
| :return: | |||
| """ | |||
| # parse stream output node value selectors of answer nodes | |||
| answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} | |||
| for answer_node_id, node_config in node_id_config_mapping.items(): | |||
| if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: | |||
| continue | |||
| # get generate route for stream output | |||
| generate_route = cls._extract_generate_route_selectors(node_config) | |||
| answer_generate_route[answer_node_id] = generate_route | |||
| # fetch answer dependencies | |||
| answer_node_ids = list(answer_generate_route.keys()) | |||
| answer_dependencies = cls._fetch_answers_dependencies( | |||
| answer_node_ids=answer_node_ids, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| node_id_config_mapping=node_id_config_mapping | |||
| ) | |||
| return AnswerStreamGenerateRoute( | |||
| answer_generate_route=answer_generate_route, | |||
| answer_dependencies=answer_dependencies | |||
| ) | |||
| @classmethod | |||
| def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: | |||
| """ | |||
| Extract generate route from node data | |||
| :param node_data: node data object | |||
| :return: | |||
| """ | |||
| variable_template_parser = VariableTemplateParser(template=node_data.answer) | |||
| variable_selectors = variable_template_parser.extract_variable_selectors() | |||
| value_selector_mapping = { | |||
| variable_selector.variable: variable_selector.value_selector | |||
| for variable_selector in variable_selectors | |||
| } | |||
| variable_keys = list(value_selector_mapping.keys()) | |||
| # format answer template | |||
| template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) | |||
| template_variable_keys = template_parser.variable_keys | |||
| # Take the intersection of variable_keys and template_variable_keys | |||
| variable_keys = list(set(variable_keys) & set(template_variable_keys)) | |||
| template = node_data.answer | |||
| for var in variable_keys: | |||
| template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') | |||
| generate_routes: list[GenerateRouteChunk] = [] | |||
| for part in template.split('Ω'): | |||
| if part: | |||
| if cls._is_variable(part, variable_keys): | |||
| var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') | |||
| value_selector = value_selector_mapping[var_key] | |||
| generate_routes.append(VarGenerateRouteChunk( | |||
| value_selector=value_selector | |||
| )) | |||
| else: | |||
| generate_routes.append(TextGenerateRouteChunk( | |||
| text=part | |||
| )) | |||
| return generate_routes | |||
| @classmethod | |||
| def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: | |||
| """ | |||
| Extract generate route selectors | |||
| :param config: node config | |||
| :return: | |||
| """ | |||
| node_data = AnswerNodeData(**config.get("data", {})) | |||
| return cls.extract_generate_route_from_node_data(node_data) | |||
| @classmethod | |||
| def _is_variable(cls, part, variable_keys): | |||
| cleaned_part = part.replace('{{', '').replace('}}', '') | |||
| return part.startswith('{{') and cleaned_part in variable_keys | |||
| @classmethod | |||
| def _fetch_answers_dependencies(cls, | |||
| answer_node_ids: list[str], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] | |||
| node_id_config_mapping: dict[str, dict] | |||
| ) -> dict[str, list[str]]: | |||
| """ | |||
| Fetch answer dependencies | |||
| :param answer_node_ids: answer node ids | |||
| :param reverse_edge_mapping: reverse edge mapping | |||
| :param node_id_config_mapping: node id config mapping | |||
| :return: | |||
| """ | |||
| answer_dependencies: dict[str, list[str]] = {} | |||
| for answer_node_id in answer_node_ids: | |||
| if answer_dependencies.get(answer_node_id) is None: | |||
| answer_dependencies[answer_node_id] = [] | |||
| cls._recursive_fetch_answer_dependencies( | |||
| current_node_id=answer_node_id, | |||
| answer_node_id=answer_node_id, | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| answer_dependencies=answer_dependencies | |||
| ) | |||
| return answer_dependencies | |||
| @classmethod | |||
| def _recursive_fetch_answer_dependencies(cls, | |||
| current_node_id: str, | |||
| answer_node_id: str, | |||
| node_id_config_mapping: dict[str, dict], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] | |||
| answer_dependencies: dict[str, list[str]] | |||
| ) -> None: | |||
| """ | |||
| Recursive fetch answer dependencies | |||
| :param current_node_id: current node id | |||
| :param answer_node_id: answer node id | |||
| :param node_id_config_mapping: node id config mapping | |||
| :param reverse_edge_mapping: reverse edge mapping | |||
| :param answer_dependencies: answer dependencies | |||
| :return: | |||
| """ | |||
| reverse_edges = reverse_edge_mapping.get(current_node_id, []) | |||
| for edge in reverse_edges: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') | |||
| if source_node_type in ( | |||
| NodeType.ANSWER.value, | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| ): | |||
| answer_dependencies[answer_node_id].append(source_node_id) | |||
| else: | |||
| cls._recursive_fetch_answer_dependencies( | |||
| current_node_id=source_node_id, | |||
| answer_node_id=answer_node_id, | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| answer_dependencies=answer_dependencies | |||
| ) | |||
| @@ -0,0 +1,221 @@ | |||
| import logging | |||
| from collections.abc import Generator | |||
| from typing import Optional, cast | |||
| from core.file.file_obj import FileVar | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes.answer.base_stream_processor import StreamProcessor | |||
| from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk | |||
| logger = logging.getLogger(__name__) | |||
| class AnswerStreamProcessor(StreamProcessor): | |||
| def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: | |||
| super().__init__(graph, variable_pool) | |||
| self.generate_routes = graph.answer_stream_generate_routes | |||
| self.route_position = {} | |||
| for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): | |||
| self.route_position[answer_node_id] = 0 | |||
| self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} | |||
| def process(self, | |||
| generator: Generator[GraphEngineEvent, None, None] | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| for event in generator: | |||
| if isinstance(event, NodeRunStartedEvent): | |||
| if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: | |||
| self.reset() | |||
| yield event | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| if event.in_iteration_id: | |||
| yield event | |||
| continue | |||
| if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: | |||
| stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ | |||
| event.route_node_state.node_id | |||
| ] | |||
| else: | |||
| stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) | |||
| self.current_stream_chunk_generating_node_ids[ | |||
| event.route_node_state.node_id | |||
| ] = stream_out_answer_node_ids | |||
| for _ in stream_out_answer_node_ids: | |||
| yield event | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| yield event | |||
| if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: | |||
| # update self.route_position after all stream event finished | |||
| for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: | |||
| self.route_position[answer_node_id] += 1 | |||
| del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] | |||
| # remove unreachable nodes | |||
| self._remove_unreachable_nodes(event) | |||
| # generate stream outputs | |||
| yield from self._generate_stream_outputs_when_node_finished(event) | |||
| else: | |||
| yield event | |||
| def reset(self) -> None: | |||
| self.route_position = {} | |||
| for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): | |||
| self.route_position[answer_node_id] = 0 | |||
| self.rest_node_ids = self.graph.node_ids.copy() | |||
| self.current_stream_chunk_generating_node_ids = {} | |||
| def _generate_stream_outputs_when_node_finished(self, | |||
| event: NodeRunSucceededEvent | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| Generate stream outputs. | |||
| :param event: node run succeeded event | |||
| :return: | |||
| """ | |||
| for answer_node_id, position in self.route_position.items(): | |||
| # all depends on answer node id not in rest node ids | |||
| if (event.route_node_state.node_id != answer_node_id | |||
| and (answer_node_id not in self.rest_node_ids | |||
| or not all(dep_id not in self.rest_node_ids | |||
| for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): | |||
| continue | |||
| route_position = self.route_position[answer_node_id] | |||
| route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:] | |||
| for route_chunk in route_chunks: | |||
| if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT: | |||
| route_chunk = cast(TextGenerateRouteChunk, route_chunk) | |||
| yield NodeRunStreamChunkEvent( | |||
| id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| chunk_content=route_chunk.text, | |||
| route_node_state=event.route_node_state, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| else: | |||
| route_chunk = cast(VarGenerateRouteChunk, route_chunk) | |||
| value_selector = route_chunk.value_selector | |||
| if not value_selector: | |||
| break | |||
| value = self.variable_pool.get( | |||
| value_selector | |||
| ) | |||
| if value is None: | |||
| break | |||
| text = value.markdown | |||
| if text: | |||
| yield NodeRunStreamChunkEvent( | |||
| id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| chunk_content=text, | |||
| from_variable_selector=value_selector, | |||
| route_node_state=event.route_node_state, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| self.route_position[answer_node_id] += 1 | |||
| def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: | |||
| """ | |||
| Is stream out support | |||
| :param event: queue text chunk event | |||
| :return: | |||
| """ | |||
| if not event.from_variable_selector: | |||
| return [] | |||
| stream_output_value_selector = event.from_variable_selector | |||
| if not stream_output_value_selector: | |||
| return [] | |||
| stream_out_answer_node_ids = [] | |||
| for answer_node_id, route_position in self.route_position.items(): | |||
| if answer_node_id not in self.rest_node_ids: | |||
| continue | |||
| # all depends on answer node id not in rest node ids | |||
| if all(dep_id not in self.rest_node_ids | |||
| for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): | |||
| if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): | |||
| continue | |||
| route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] | |||
| if route_chunk.type != GenerateRouteChunk.ChunkType.VAR: | |||
| continue | |||
| route_chunk = cast(VarGenerateRouteChunk, route_chunk) | |||
| value_selector = route_chunk.value_selector | |||
| # check chunk node id is before current node id or equal to current node id | |||
| if value_selector != stream_output_value_selector: | |||
| continue | |||
| stream_out_answer_node_ids.append(answer_node_id) | |||
| return stream_out_answer_node_ids | |||
| @classmethod | |||
| def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]: | |||
| """ | |||
| Fetch files from variable value | |||
| :param value: variable value | |||
| :return: | |||
| """ | |||
| if not value: | |||
| return [] | |||
| files = [] | |||
| if isinstance(value, list): | |||
| for item in value: | |||
| file_var = cls._get_file_var_from_value(item) | |||
| if file_var: | |||
| files.append(file_var) | |||
| elif isinstance(value, dict): | |||
| file_var = cls._get_file_var_from_value(value) | |||
| if file_var: | |||
| files.append(file_var) | |||
| return files | |||
| @classmethod | |||
| def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]: | |||
| """ | |||
| Get file var from value | |||
| :param value: variable value | |||
| :return: | |||
| """ | |||
| if not value: | |||
| return None | |||
| if isinstance(value, dict): | |||
| if '__variant' in value and value['__variant'] == FileVar.__name__: | |||
| return value | |||
| elif isinstance(value, FileVar): | |||
| return value.to_dict() | |||
| return None | |||
| @@ -0,0 +1,71 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| class StreamProcessor(ABC): | |||
| def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: | |||
| self.graph = graph | |||
| self.variable_pool = variable_pool | |||
| self.rest_node_ids = graph.node_ids.copy() | |||
| @abstractmethod | |||
| def process(self, | |||
| generator: Generator[GraphEngineEvent, None, None] | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| raise NotImplementedError | |||
| def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: | |||
| finished_node_id = event.route_node_state.node_id | |||
| if finished_node_id not in self.rest_node_ids: | |||
| return | |||
| # remove finished node id | |||
| self.rest_node_ids.remove(finished_node_id) | |||
| run_result = event.route_node_state.node_run_result | |||
| if not run_result: | |||
| return | |||
| if run_result.edge_source_handle: | |||
| reachable_node_ids = [] | |||
| unreachable_first_node_ids = [] | |||
| for edge in self.graph.edge_mapping[finished_node_id]: | |||
| if (edge.run_condition | |||
| and edge.run_condition.branch_identify | |||
| and run_result.edge_source_handle == edge.run_condition.branch_identify): | |||
| reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) | |||
| continue | |||
| else: | |||
| unreachable_first_node_ids.append(edge.target_node_id) | |||
| for node_id in unreachable_first_node_ids: | |||
| self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) | |||
| def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: | |||
| node_ids = [] | |||
| for edge in self.graph.edge_mapping.get(node_id, []): | |||
| if edge.target_node_id == self.graph.root_node_id: | |||
| continue | |||
| node_ids.append(edge.target_node_id) | |||
| node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) | |||
| return node_ids | |||
| def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: | |||
| """ | |||
| remove target node ids until merge | |||
| """ | |||
| if node_id not in self.rest_node_ids: | |||
| return | |||
| self.rest_node_ids.remove(node_id) | |||
| for edge in self.graph.edge_mapping.get(node_id, []): | |||
| if edge.target_node_id in reachable_node_ids: | |||
| continue | |||
| self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids) | |||
| @@ -1,5 +1,6 @@ | |||
| from enum import Enum | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| @@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData): | |||
| """ | |||
| Answer Node Data. | |||
| """ | |||
| answer: str | |||
| answer: str = Field(..., description="answer template string") | |||
| class GenerateRouteChunk(BaseModel): | |||
| """ | |||
| Generate Route Chunk. | |||
| """ | |||
| type: str | |||
| class ChunkType(Enum): | |||
| VAR = "var" | |||
| TEXT = "text" | |||
| type: ChunkType = Field(..., description="generate route chunk type") | |||
| class VarGenerateRouteChunk(GenerateRouteChunk): | |||
| """ | |||
| Var Generate Route Chunk. | |||
| """ | |||
| type: str = "var" | |||
| value_selector: list[str] | |||
| type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR | |||
| """generate route chunk type""" | |||
| value_selector: list[str] = Field(..., description="value selector") | |||
| class TextGenerateRouteChunk(GenerateRouteChunk): | |||
| """ | |||
| Text Generate Route Chunk. | |||
| """ | |||
| type: str = "text" | |||
| text: str | |||
| type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT | |||
| """generate route chunk type""" | |||
| text: str = Field(..., description="text") | |||
| class AnswerNodeDoubleLink(BaseModel): | |||
| node_id: str = Field(..., description="node id") | |||
| source_node_ids: list[str] = Field(..., description="source node ids") | |||
| target_node_ids: list[str] = Field(..., description="target node ids") | |||
| class AnswerStreamGenerateRoute(BaseModel): | |||
| """ | |||
| AnswerStreamGenerateRoute entity | |||
| """ | |||
| answer_dependencies: dict[str, list[str]] = Field( | |||
| ..., | |||
| description="answer dependencies (answer node id -> dependent answer node ids)" | |||
| ) | |||
| answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( | |||
| ..., | |||
| description="answer generate route (answer node id -> generate route chunks)" | |||
| ) | |||
| @@ -1,142 +1,103 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from models import WorkflowNodeExecutionStatus | |||
| class UserFrom(Enum): | |||
| """ | |||
| User from | |||
| """ | |||
| ACCOUNT = "account" | |||
| END_USER = "end-user" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "UserFrom": | |||
| """ | |||
| Value of | |||
| :param value: value | |||
| :return: | |||
| """ | |||
| for item in cls: | |||
| if item.value == value: | |||
| return item | |||
| raise ValueError(f"Invalid value: {value}") | |||
| from core.workflow.graph_engine.entities.event import InNodeEvent | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunEvent | |||
| class BaseNode(ABC): | |||
| _node_data_cls: type[BaseNodeData] | |||
| _node_type: NodeType | |||
| tenant_id: str | |||
| app_id: str | |||
| workflow_id: str | |||
| user_id: str | |||
| user_from: UserFrom | |||
| invoke_from: InvokeFrom | |||
| workflow_call_depth: int | |||
| node_id: str | |||
| node_data: BaseNodeData | |||
| node_run_result: Optional[NodeRunResult] = None | |||
| callbacks: Sequence[WorkflowCallback] | |||
| is_answer_previous_node: bool = False | |||
| def __init__(self, tenant_id: str, | |||
| app_id: str, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| def __init__(self, | |||
| id: str, | |||
| config: Mapping[str, Any], | |||
| callbacks: Sequence[WorkflowCallback] | None = None, | |||
| workflow_call_depth: int = 0) -> None: | |||
| self.tenant_id = tenant_id | |||
| self.app_id = app_id | |||
| self.workflow_id = workflow_id | |||
| self.user_id = user_id | |||
| self.user_from = user_from | |||
| self.invoke_from = invoke_from | |||
| self.workflow_call_depth = workflow_call_depth | |||
| # TODO: May need to check if key exists. | |||
| self.node_id = config["id"] | |||
| if not self.node_id: | |||
| graph_init_params: GraphInitParams, | |||
| graph: Graph, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| previous_node_id: Optional[str] = None, | |||
| thread_pool_id: Optional[str] = None) -> None: | |||
| self.id = id | |||
| self.tenant_id = graph_init_params.tenant_id | |||
| self.app_id = graph_init_params.app_id | |||
| self.workflow_type = graph_init_params.workflow_type | |||
| self.workflow_id = graph_init_params.workflow_id | |||
| self.graph_config = graph_init_params.graph_config | |||
| self.user_id = graph_init_params.user_id | |||
| self.user_from = graph_init_params.user_from | |||
| self.invoke_from = graph_init_params.invoke_from | |||
| self.workflow_call_depth = graph_init_params.call_depth | |||
| self.graph = graph | |||
| self.graph_runtime_state = graph_runtime_state | |||
| self.previous_node_id = previous_node_id | |||
| self.thread_pool_id = thread_pool_id | |||
| node_id = config.get("id") | |||
| if not node_id: | |||
| raise ValueError("Node ID is required.") | |||
| self.node_id = node_id | |||
| self.node_data = self._node_data_cls(**config.get("data", {})) | |||
| self.callbacks = callbacks or [] | |||
| @abstractmethod | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) \ | |||
| -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| def run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run node entry | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| try: | |||
| result = self._run( | |||
| variable_pool=variable_pool | |||
| ) | |||
| self.node_run_result = result | |||
| return result | |||
| except Exception as e: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| ) | |||
| result = self._run() | |||
| def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: | |||
| """ | |||
| Publish text chunk | |||
| :param text: chunk text | |||
| :param value_selector: value selector | |||
| :return: | |||
| """ | |||
| if self.callbacks: | |||
| for callback in self.callbacks: | |||
| callback.on_node_text_chunk( | |||
| node_id=self.node_id, | |||
| text=text, | |||
| metadata={ | |||
| "node_type": self.node_type, | |||
| "is_answer_previous_node": self.is_answer_previous_node, | |||
| "value_selector": value_selector | |||
| } | |||
| ) | |||
| if isinstance(result, NodeRunResult): | |||
| yield RunCompletedEvent( | |||
| run_result=result | |||
| ) | |||
| else: | |||
| yield from result | |||
| @classmethod | |||
| def extract_variable_selector_to_variable_mapping(cls, config: dict): | |||
| def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param config: node config | |||
| :return: | |||
| """ | |||
| node_id = config.get("id") | |||
| if not node_id: | |||
| raise ValueError("Node ID is required when extracting variable selector to variable mapping.") | |||
| node_data = cls._node_data_cls(**config.get("data", {})) | |||
| return cls._extract_variable_selector_to_variable_mapping(node_data) | |||
| return cls._extract_variable_selector_to_variable_mapping( | |||
| graph_config=graph_config, | |||
| node_id=node_id, | |||
| node_data=node_data | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: BaseNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -158,38 +119,3 @@ class BaseNode(ABC): | |||
| :return: | |||
| """ | |||
| return self._node_type | |||
| class BaseIterationNode(BaseNode): | |||
| @abstractmethod | |||
| def _run(self, variable_pool: VariablePool) -> BaseIterationState: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| def run(self, variable_pool: VariablePool) -> BaseIterationState: | |||
| """ | |||
| Run node entry | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| return self._run(variable_pool=variable_pool) | |||
| def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: | |||
| """ | |||
| Get next iteration start node id based on the graph. | |||
| :param graph: graph | |||
| :return: next node id | |||
| """ | |||
| return self._get_next_iteration(variable_pool, state) | |||
| @abstractmethod | |||
| def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: | |||
| """ | |||
| Get next iteration start node id based on the graph. | |||
| :param graph: graph | |||
| :return: next node id | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -1,4 +1,5 @@ | |||
| from typing import Optional, Union, cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from configs import dify_config | |||
| from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage | |||
| @@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider | |||
| from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider | |||
| from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.code.entities import CodeNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -33,13 +33,13 @@ class CodeNode(BaseNode): | |||
| return code_provider.get_default_config() | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run code | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_data = cast(CodeNodeData, self.node_data) | |||
| node_data = self.node_data | |||
| node_data = cast(CodeNodeData, node_data) | |||
| # Get code language | |||
| code_language = node_data.code_language | |||
| @@ -49,7 +49,7 @@ class CodeNode(BaseNode): | |||
| variables = {} | |||
| for variable_selector in node_data.variables: | |||
| variable = variable_selector.variable | |||
| value = variable_pool.get_any(variable_selector.value_selector) | |||
| value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) | |||
| variables[variable] = value | |||
| # Run code | |||
| @@ -311,13 +311,19 @@ class CodeNode(BaseNode): | |||
| return transformed_result | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: CodeNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return { | |||
| variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables | |||
| node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables | |||
| } | |||
| @@ -1,8 +1,7 @@ | |||
| from typing import cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.end.entities import EndNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -12,10 +11,9 @@ class EndNode(BaseNode): | |||
| _node_data_cls = EndNodeData | |||
| _node_type = NodeType.END | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_data = self.node_data | |||
| @@ -24,7 +22,7 @@ class EndNode(BaseNode): | |||
| outputs = {} | |||
| for variable_selector in output_variables: | |||
| value = variable_pool.get_any(variable_selector.value_selector) | |||
| value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) | |||
| outputs[variable_selector.variable] = value | |||
| return NodeRunResult( | |||
| @@ -34,52 +32,16 @@ class EndNode(BaseNode): | |||
| ) | |||
| @classmethod | |||
| def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]: | |||
| """ | |||
| Extract generate nodes | |||
| :param graph: graph | |||
| :param config: node config | |||
| :return: | |||
| """ | |||
| node_data = cls._node_data_cls(**config.get("data", {})) | |||
| node_data = cast(EndNodeData, node_data) | |||
| return cls.extract_generate_nodes_from_node_data(graph, node_data) | |||
| @classmethod | |||
| def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]: | |||
| """ | |||
| Extract generate nodes from node data | |||
| :param graph: graph | |||
| :param node_data: node data object | |||
| :return: | |||
| """ | |||
| nodes = graph.get('nodes', []) | |||
| node_mapping = {node.get('id'): node for node in nodes} | |||
| variable_selectors = node_data.outputs | |||
| generate_nodes = [] | |||
| for variable_selector in variable_selectors: | |||
| if not variable_selector.value_selector: | |||
| continue | |||
| node_id = variable_selector.value_selector[0] | |||
| if node_id != 'sys' and node_id in node_mapping: | |||
| node = node_mapping[node_id] | |||
| node_type = node.get('data', {}).get('type') | |||
| if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': | |||
| generate_nodes.append(node_id) | |||
| # remove duplicates | |||
| generate_nodes = list(set(generate_nodes)) | |||
| return generate_nodes | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: EndNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -0,0 +1,148 @@ | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam | |||
| class EndStreamGeneratorRouter: | |||
| @classmethod | |||
| def init(cls, | |||
| node_id_config_mapping: dict[str, dict], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] | |||
| node_parallel_mapping: dict[str, str] | |||
| ) -> EndStreamParam: | |||
| """ | |||
| Get stream generate routes. | |||
| :return: | |||
| """ | |||
| # parse stream output node value selector of end nodes | |||
| end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} | |||
| for end_node_id, node_config in node_id_config_mapping.items(): | |||
| if not node_config.get('data', {}).get('type') == NodeType.END.value: | |||
| continue | |||
| # skip end node in parallel | |||
| if end_node_id in node_parallel_mapping: | |||
| continue | |||
| # get generate route for stream output | |||
| stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) | |||
| end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors | |||
| # fetch end dependencies | |||
| end_node_ids = list(end_stream_variable_selectors_mapping.keys()) | |||
| end_dependencies = cls._fetch_ends_dependencies( | |||
| end_node_ids=end_node_ids, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| node_id_config_mapping=node_id_config_mapping | |||
| ) | |||
| return EndStreamParam( | |||
| end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, | |||
| end_dependencies=end_dependencies | |||
| ) | |||
| @classmethod | |||
| def extract_stream_variable_selector_from_node_data(cls, | |||
| node_id_config_mapping: dict[str, dict], | |||
| node_data: EndNodeData) -> list[list[str]]: | |||
| """ | |||
| Extract stream variable selector from node data | |||
| :param node_id_config_mapping: node id config mapping | |||
| :param node_data: node data object | |||
| :return: | |||
| """ | |||
| variable_selectors = node_data.outputs | |||
| value_selectors = [] | |||
| for variable_selector in variable_selectors: | |||
| if not variable_selector.value_selector: | |||
| continue | |||
| node_id = variable_selector.value_selector[0] | |||
| if node_id != 'sys' and node_id in node_id_config_mapping: | |||
| node = node_id_config_mapping[node_id] | |||
| node_type = node.get('data', {}).get('type') | |||
| if ( | |||
| variable_selector.value_selector not in value_selectors | |||
| and node_type == NodeType.LLM.value | |||
| and variable_selector.value_selector[1] == 'text' | |||
| ): | |||
| value_selectors.append(variable_selector.value_selector) | |||
| return value_selectors | |||
| @classmethod | |||
| def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ | |||
| -> list[list[str]]: | |||
| """ | |||
| Extract stream variable selector from node config | |||
| :param node_id_config_mapping: node id config mapping | |||
| :param config: node config | |||
| :return: | |||
| """ | |||
| node_data = EndNodeData(**config.get("data", {})) | |||
| return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) | |||
| @classmethod | |||
| def _fetch_ends_dependencies(cls, | |||
| end_node_ids: list[str], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] | |||
| node_id_config_mapping: dict[str, dict] | |||
| ) -> dict[str, list[str]]: | |||
| """ | |||
| Fetch end dependencies | |||
| :param end_node_ids: end node ids | |||
| :param reverse_edge_mapping: reverse edge mapping | |||
| :param node_id_config_mapping: node id config mapping | |||
| :return: | |||
| """ | |||
| end_dependencies: dict[str, list[str]] = {} | |||
| for end_node_id in end_node_ids: | |||
| if end_dependencies.get(end_node_id) is None: | |||
| end_dependencies[end_node_id] = [] | |||
| cls._recursive_fetch_end_dependencies( | |||
| current_node_id=end_node_id, | |||
| end_node_id=end_node_id, | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| end_dependencies=end_dependencies | |||
| ) | |||
| return end_dependencies | |||
| @classmethod | |||
| def _recursive_fetch_end_dependencies(cls, | |||
| current_node_id: str, | |||
| end_node_id: str, | |||
| node_id_config_mapping: dict[str, dict], | |||
| reverse_edge_mapping: dict[str, list["GraphEdge"]], | |||
| # type: ignore[name-defined] | |||
| end_dependencies: dict[str, list[str]] | |||
| ) -> None: | |||
| """ | |||
| Recursive fetch end dependencies | |||
| :param current_node_id: current node id | |||
| :param end_node_id: end node id | |||
| :param node_id_config_mapping: node id config mapping | |||
| :param reverse_edge_mapping: reverse edge mapping | |||
| :param end_dependencies: end dependencies | |||
| :return: | |||
| """ | |||
| reverse_edges = reverse_edge_mapping.get(current_node_id, []) | |||
| for edge in reverse_edges: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') | |||
| if source_node_type in ( | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| ): | |||
| end_dependencies[end_node_id].append(source_node_id) | |||
| else: | |||
| cls._recursive_fetch_end_dependencies( | |||
| current_node_id=source_node_id, | |||
| end_node_id=end_node_id, | |||
| node_id_config_mapping=node_id_config_mapping, | |||
| reverse_edge_mapping=reverse_edge_mapping, | |||
| end_dependencies=end_dependencies | |||
| ) | |||
| @@ -0,0 +1,191 @@ | |||
| import logging | |||
| from collections.abc import Generator | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes.answer.base_stream_processor import StreamProcessor | |||
| logger = logging.getLogger(__name__) | |||
| class EndStreamProcessor(StreamProcessor): | |||
| def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: | |||
| super().__init__(graph, variable_pool) | |||
| self.end_stream_param = graph.end_stream_param | |||
| self.route_position = {} | |||
| for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): | |||
| self.route_position[end_node_id] = 0 | |||
| self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} | |||
| self.has_outputed = False | |||
| self.outputed_node_ids = set() | |||
| def process(self, | |||
| generator: Generator[GraphEngineEvent, None, None] | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| for event in generator: | |||
| if isinstance(event, NodeRunStartedEvent): | |||
| if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: | |||
| self.reset() | |||
| yield event | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| if event.in_iteration_id: | |||
| if self.has_outputed and event.node_id not in self.outputed_node_ids: | |||
| event.chunk_content = '\n' + event.chunk_content | |||
| self.outputed_node_ids.add(event.node_id) | |||
| self.has_outputed = True | |||
| yield event | |||
| continue | |||
| if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: | |||
| stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ | |||
| event.route_node_state.node_id | |||
| ] | |||
| else: | |||
| stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) | |||
| self.current_stream_chunk_generating_node_ids[ | |||
| event.route_node_state.node_id | |||
| ] = stream_out_end_node_ids | |||
| if stream_out_end_node_ids: | |||
| if self.has_outputed and event.node_id not in self.outputed_node_ids: | |||
| event.chunk_content = '\n' + event.chunk_content | |||
| self.outputed_node_ids.add(event.node_id) | |||
| self.has_outputed = True | |||
| yield event | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| yield event | |||
| if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: | |||
| # update self.route_position after all stream event finished | |||
| for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: | |||
| self.route_position[end_node_id] += 1 | |||
| del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] | |||
| # remove unreachable nodes | |||
| self._remove_unreachable_nodes(event) | |||
| # generate stream outputs | |||
| yield from self._generate_stream_outputs_when_node_finished(event) | |||
| else: | |||
| yield event | |||
| def reset(self) -> None: | |||
| self.route_position = {} | |||
| for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): | |||
| self.route_position[end_node_id] = 0 | |||
| self.rest_node_ids = self.graph.node_ids.copy() | |||
| self.current_stream_chunk_generating_node_ids = {} | |||
| def _generate_stream_outputs_when_node_finished(self, | |||
| event: NodeRunSucceededEvent | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| Generate stream outputs. | |||
| :param event: node run succeeded event | |||
| :return: | |||
| """ | |||
| for end_node_id, position in self.route_position.items(): | |||
| # all depends on end node id not in rest node ids | |||
| if (event.route_node_state.node_id != end_node_id | |||
| and (end_node_id not in self.rest_node_ids | |||
| or not all(dep_id not in self.rest_node_ids | |||
| for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): | |||
| continue | |||
| route_position = self.route_position[end_node_id] | |||
| position = 0 | |||
| value_selectors = [] | |||
| for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: | |||
| if position >= route_position: | |||
| value_selectors.append(current_value_selectors) | |||
| position += 1 | |||
| for value_selector in value_selectors: | |||
| if not value_selector: | |||
| continue | |||
| value = self.variable_pool.get( | |||
| value_selector | |||
| ) | |||
| if value is None: | |||
| break | |||
| text = value.markdown | |||
| if text: | |||
| current_node_id = value_selector[0] | |||
| if self.has_outputed and current_node_id not in self.outputed_node_ids: | |||
| text = '\n' + text | |||
| self.outputed_node_ids.add(current_node_id) | |||
| self.has_outputed = True | |||
| yield NodeRunStreamChunkEvent( | |||
| id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| chunk_content=text, | |||
| from_variable_selector=value_selector, | |||
| route_node_state=event.route_node_state, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ) | |||
| self.route_position[end_node_id] += 1 | |||
| def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: | |||
| """ | |||
| Is stream out support | |||
| :param event: queue text chunk event | |||
| :return: | |||
| """ | |||
| if not event.from_variable_selector: | |||
| return [] | |||
| stream_output_value_selector = event.from_variable_selector | |||
| if not stream_output_value_selector: | |||
| return [] | |||
| stream_out_end_node_ids = [] | |||
| for end_node_id, route_position in self.route_position.items(): | |||
| if end_node_id not in self.rest_node_ids: | |||
| continue | |||
| # all depends on end node id not in rest node ids | |||
| if all(dep_id not in self.rest_node_ids | |||
| for dep_id in self.end_stream_param.end_dependencies[end_node_id]): | |||
| if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): | |||
| continue | |||
| position = 0 | |||
| value_selector = None | |||
| for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: | |||
| if position == route_position: | |||
| value_selector = current_value_selectors | |||
| break | |||
| position += 1 | |||
| if not value_selector: | |||
| continue | |||
| # check chunk node id is before current node id or equal to current node id | |||
| if value_selector != stream_output_value_selector: | |||
| continue | |||
| stream_out_end_node_ids.append(end_node_id) | |||
| return stream_out_end_node_ids | |||
| @@ -1,3 +1,5 @@ | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| @@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData): | |||
| END Node Data. | |||
| """ | |||
| outputs: list[VariableSelector] | |||
| class EndStreamParam(BaseModel): | |||
| """ | |||
| EndStreamParam entity | |||
| """ | |||
| end_dependencies: dict[str, list[str]] = Field( | |||
| ..., | |||
| description="end dependencies (end node id -> dependent node ids)" | |||
| ) | |||
| end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( | |||
| ..., | |||
| description="end stream variable selector mapping (end node id -> stream variable selectors)" | |||
| ) | |||
| @@ -0,0 +1,20 @@ | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| class RunCompletedEvent(BaseModel): | |||
| run_result: NodeRunResult = Field(..., description="run result") | |||
| class RunStreamChunkEvent(BaseModel): | |||
| chunk_content: str = Field(..., description="chunk content") | |||
| from_variable_selector: list[str] = Field(..., description="from variable selector") | |||
| class RunRetrieverResourceEvent(BaseModel): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | |||
| @@ -1,15 +1,14 @@ | |||
| import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from mimetypes import guess_extension | |||
| from os import path | |||
| from typing import cast | |||
| from typing import Any, cast | |||
| from configs import dify_config | |||
| from core.app.segments import parser | |||
| from core.file.file_obj import FileTransferMethod, FileType, FileVar | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.http_request.entities import ( | |||
| HttpRequestNodeData, | |||
| @@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode): | |||
| }, | |||
| } | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) | |||
| # TODO: Switch to use segment directly | |||
| if node_data.authorization.config and node_data.authorization.config.api_key: | |||
| node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text | |||
| node_data.authorization.config.api_key = parser.convert_template( | |||
| template=node_data.authorization.config.api_key, | |||
| variable_pool=self.graph_runtime_state.variable_pool | |||
| ).text | |||
| # init http executor | |||
| http_executor = None | |||
| try: | |||
| http_executor = HttpExecutor( | |||
| node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool | |||
| node_data=node_data, | |||
| timeout=self._get_request_timeout(node_data), | |||
| variable_pool=self.graph_runtime_state.variable_pool | |||
| ) | |||
| # invoke http executor | |||
| @@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode): | |||
| return timeout | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: HttpRequestNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| node_data = cast(HttpRequestNodeData, node_data) | |||
| try: | |||
| http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) | |||
| @@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode): | |||
| variable_mapping = {} | |||
| for variable_selector in variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector | |||
| return variable_mapping | |||
| except Exception as e: | |||
| @@ -3,20 +3,7 @@ from typing import Literal, Optional | |||
| from pydantic import BaseModel | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| class Condition(BaseModel): | |||
| """ | |||
| Condition entity | |||
| """ | |||
| variable_selector: list[str] | |||
| comparison_operator: Literal[ | |||
| # for string or array | |||
| "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match", | |||
| # for number | |||
| "=", "≠", ">", "<", "≥", "≤", "null", "not null" | |||
| ] | |||
| value: Optional[str] = None | |||
| from core.workflow.utils.condition.entities import Condition | |||
| class IfElseNodeData(BaseNodeData): | |||
| @@ -1,13 +1,10 @@ | |||
| import re | |||
| from collections.abc import Sequence | |||
| from typing import Optional, cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from core.workflow.nodes.if_else.entities import IfElseNodeData | |||
| from core.workflow.utils.condition.processor import ConditionProcessor | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -15,31 +12,35 @@ class IfElseNode(BaseNode): | |||
| _node_data_cls = IfElseNodeData | |||
| _node_type = NodeType.IF_ELSE | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_data = self.node_data | |||
| node_data = cast(IfElseNodeData, node_data) | |||
| node_inputs = { | |||
| node_inputs: dict[str, list] = { | |||
| "conditions": [] | |||
| } | |||
| process_datas = { | |||
| process_datas: dict[str, list] = { | |||
| "condition_results": [] | |||
| } | |||
| input_conditions = [] | |||
| final_result = False | |||
| selected_case_id = None | |||
| condition_processor = ConditionProcessor() | |||
| try: | |||
| # Check if the new cases structure is used | |||
| if node_data.cases: | |||
| for case in node_data.cases: | |||
| input_conditions, group_result = self.process_conditions(variable_pool, case.conditions) | |||
| input_conditions, group_result = condition_processor.process_conditions( | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| conditions=case.conditions | |||
| ) | |||
| # Apply the logical operator for the current case | |||
| final_result = all(group_result) if case.logical_operator == "and" else any(group_result) | |||
| @@ -58,7 +59,10 @@ class IfElseNode(BaseNode): | |||
| else: | |||
| # Fallback to old structure if cases are not defined | |||
| input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) | |||
| input_conditions, group_result = condition_processor.process_conditions( | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| conditions=node_data.conditions | |||
| ) | |||
| final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) | |||
| @@ -94,376 +98,17 @@ class IfElseNode(BaseNode): | |||
| return data | |||
| def evaluate_condition( | |||
| self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str | |||
| ) -> bool: | |||
| """ | |||
| Evaluate condition | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :param comparison_operator: comparison operator | |||
| :return: bool | |||
| """ | |||
| if comparison_operator == "contains": | |||
| return self._assert_contains(actual_value, expected_value) | |||
| elif comparison_operator == "not contains": | |||
| return self._assert_not_contains(actual_value, expected_value) | |||
| elif comparison_operator == "start with": | |||
| return self._assert_start_with(actual_value, expected_value) | |||
| elif comparison_operator == "end with": | |||
| return self._assert_end_with(actual_value, expected_value) | |||
| elif comparison_operator == "is": | |||
| return self._assert_is(actual_value, expected_value) | |||
| elif comparison_operator == "is not": | |||
| return self._assert_is_not(actual_value, expected_value) | |||
| elif comparison_operator == "empty": | |||
| return self._assert_empty(actual_value) | |||
| elif comparison_operator == "not empty": | |||
| return self._assert_not_empty(actual_value) | |||
| elif comparison_operator == "=": | |||
| return self._assert_equal(actual_value, expected_value) | |||
| elif comparison_operator == "≠": | |||
| return self._assert_not_equal(actual_value, expected_value) | |||
| elif comparison_operator == ">": | |||
| return self._assert_greater_than(actual_value, expected_value) | |||
| elif comparison_operator == "<": | |||
| return self._assert_less_than(actual_value, expected_value) | |||
| elif comparison_operator == "≥": | |||
| return self._assert_greater_than_or_equal(actual_value, expected_value) | |||
| elif comparison_operator == "≤": | |||
| return self._assert_less_than_or_equal(actual_value, expected_value) | |||
| elif comparison_operator == "null": | |||
| return self._assert_null(actual_value) | |||
| elif comparison_operator == "not null": | |||
| return self._assert_not_null(actual_value) | |||
| elif comparison_operator == "regex match": | |||
| return self._assert_regex_match(actual_value, expected_value) | |||
| else: | |||
| raise ValueError(f"Invalid comparison operator: {comparison_operator}") | |||
| def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): | |||
| input_conditions = [] | |||
| group_result = [] | |||
| for condition in conditions: | |||
| actual_variable = variable_pool.get_any(condition.variable_selector) | |||
| if condition.value is not None: | |||
| variable_template_parser = VariableTemplateParser(template=condition.value) | |||
| expected_value = variable_template_parser.extract_variable_selectors() | |||
| variable_selectors = variable_template_parser.extract_variable_selectors() | |||
| if variable_selectors: | |||
| for variable_selector in variable_selectors: | |||
| value = variable_pool.get_any(variable_selector.value_selector) | |||
| expected_value = variable_template_parser.format({variable_selector.variable: value}) | |||
| else: | |||
| expected_value = condition.value | |||
| else: | |||
| expected_value = None | |||
| comparison_operator = condition.comparison_operator | |||
| input_conditions.append( | |||
| { | |||
| "actual_value": actual_variable, | |||
| "expected_value": expected_value, | |||
| "comparison_operator": comparison_operator | |||
| } | |||
| ) | |||
| result = self.evaluate_condition(actual_variable, expected_value, comparison_operator) | |||
| group_result.append(result) | |||
| return input_conditions, group_result | |||
| def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: | |||
| """ | |||
| Assert contains | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str | list): | |||
| raise ValueError('Invalid actual value type: string or array') | |||
| if expected_value not in actual_value: | |||
| return False | |||
| return True | |||
| def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: | |||
| """ | |||
| Assert not contains | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return True | |||
| if not isinstance(actual_value, str | list): | |||
| raise ValueError('Invalid actual value type: string or array') | |||
| if expected_value in actual_value: | |||
| return False | |||
| return True | |||
| def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert start with | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if not actual_value.startswith(expected_value): | |||
| return False | |||
| return True | |||
| def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert end with | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if not actual_value.endswith(expected_value): | |||
| return False | |||
| return True | |||
| def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert is | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if actual_value != expected_value: | |||
| return False | |||
| return True | |||
| def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert is not | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if actual_value == expected_value: | |||
| return False | |||
| return True | |||
| def _assert_empty(self, actual_value: Optional[str]) -> bool: | |||
| """ | |||
| Assert empty | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return True | |||
| return False | |||
| def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert empty | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| return re.search(expected_value, actual_value) is not None | |||
| def _assert_not_empty(self, actual_value: Optional[str]) -> bool: | |||
| """ | |||
| Assert not empty | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value: | |||
| return True | |||
| return False | |||
| def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value != expected_value: | |||
| return False | |||
| return True | |||
| def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert not equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value == expected_value: | |||
| return False | |||
| return True | |||
| def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert greater than | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value <= expected_value: | |||
| return False | |||
| return True | |||
| def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert less than | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value >= expected_value: | |||
| return False | |||
| return True | |||
| def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert greater than or equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value < expected_value: | |||
| return False | |||
| return True | |||
| def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: | |||
| """ | |||
| Assert less than or equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value > expected_value: | |||
| return False | |||
| return True | |||
| def _assert_null(self, actual_value: Optional[int | float]) -> bool: | |||
| """ | |||
| Assert null | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return True | |||
| return False | |||
| def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: | |||
| """ | |||
| Assert not null | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value is not None: | |||
| return True | |||
| return False | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: IfElseNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Any, Optional | |||
| from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState | |||
| from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData | |||
| class IterationNodeData(BaseIterationNodeData): | |||
| @@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData): | |||
| iterator_selector: list[str] # variable selector | |||
| output_selector: list[str] # output selector | |||
| class IterationStartNodeData(BaseNodeData): | |||
| """ | |||
| Iteration Start Node Data. | |||
| """ | |||
| pass | |||
| class IterationState(BaseIterationState): | |||
| """ | |||
| Iteration State. | |||
| @@ -1,124 +1,371 @@ | |||
| from typing import cast | |||
| import logging | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from datetime import datetime, timezone | |||
| from typing import Any, cast | |||
| from configs import dify_config | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.base_node_data_entities import BaseIterationState | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseIterationNode | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseGraphEvent, | |||
| BaseNodeEvent, | |||
| BaseParallelBranchEvent, | |||
| GraphRunFailedEvent, | |||
| InNodeEvent, | |||
| IterationRunFailedEvent, | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.run_condition import RunCondition | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunEvent | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData | |||
| from core.workflow.utils.condition.entities import Condition | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| logger = logging.getLogger(__name__) | |||
| class IterationNode(BaseIterationNode): | |||
| class IterationNode(BaseNode): | |||
| """ | |||
| Iteration Node. | |||
| """ | |||
| _node_data_cls = IterationNodeData | |||
| _node_type = NodeType.ITERATION | |||
| def _run(self, variable_pool: VariablePool) -> BaseIterationState: | |||
| def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run the node. | |||
| """ | |||
| self.node_data = cast(IterationNodeData, self.node_data) | |||
| iterator = variable_pool.get_any(self.node_data.iterator_selector) | |||
| if not isinstance(iterator, list): | |||
| raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") | |||
| iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) | |||
| state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ | |||
| 'iterator_selector': iterator | |||
| }, outputs=[], metadata=IterationState.MetaData( | |||
| iterator_length=len(iterator) if iterator is not None else 0 | |||
| )) | |||
| if not iterator_list_segment: | |||
| raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") | |||
| self._set_current_iteration_variable(variable_pool, state) | |||
| return state | |||
| iterator_list_value = iterator_list_segment.to_object() | |||
| def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: | |||
| """ | |||
| Get next iteration start node id based on the graph. | |||
| :param graph: graph | |||
| :return: next node id | |||
| """ | |||
| # resolve current output | |||
| self._resolve_current_output(variable_pool, state) | |||
| # move to next iteration | |||
| self._next_iteration(variable_pool, state) | |||
| node_data = cast(IterationNodeData, self.node_data) | |||
| if self._reached_iteration_limit(variable_pool, state): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| if not isinstance(iterator_list_value, list): | |||
| raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") | |||
| inputs = { | |||
| "iterator_selector": iterator_list_value | |||
| } | |||
| graph_config = self.graph_config | |||
| if not self.node_data.start_node_id: | |||
| raise ValueError(f'field start_node_id in iteration {self.node_id} not found') | |||
| root_node_id = self.node_data.start_node_id | |||
| # init graph | |||
| iteration_graph = Graph.init( | |||
| graph_config=graph_config, | |||
| root_node_id=root_node_id | |||
| ) | |||
| if not iteration_graph: | |||
| raise ValueError('iteration graph not found') | |||
| leaf_node_ids = iteration_graph.get_leaf_node_ids() | |||
| iteration_leaf_node_ids = [] | |||
| for leaf_node_id in leaf_node_ids: | |||
| node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) | |||
| if not node_config: | |||
| continue | |||
| leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") | |||
| if not leaf_node_iteration_id: | |||
| continue | |||
| if leaf_node_iteration_id != self.node_id: | |||
| continue | |||
| iteration_leaf_node_ids.append(leaf_node_id) | |||
| # add condition of end nodes to root node | |||
| iteration_graph.add_extra_edge( | |||
| source_node_id=leaf_node_id, | |||
| target_node_id=root_node_id, | |||
| run_condition=RunCondition( | |||
| type="condition", | |||
| conditions=[ | |||
| Condition( | |||
| variable_selector=[self.node_id, "index"], | |||
| comparison_operator="<", | |||
| value=str(len(iterator_list_value)) | |||
| ) | |||
| ] | |||
| ) | |||
| ) | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| # append iteration variable (item, index) to variable pool | |||
| variable_pool.add( | |||
| [self.node_id, 'index'], | |||
| 0 | |||
| ) | |||
| variable_pool.add( | |||
| [self.node_id, 'item'], | |||
| iterator_list_value[0] | |||
| ) | |||
| # init graph engine | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| graph_engine = GraphEngine( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_id, | |||
| workflow_type=self.workflow_type, | |||
| workflow_id=self.workflow_id, | |||
| user_id=self.user_id, | |||
| user_from=self.user_from, | |||
| invoke_from=self.invoke_from, | |||
| call_depth=self.workflow_call_depth, | |||
| graph=iteration_graph, | |||
| graph_config=graph_config, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, | |||
| max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME | |||
| ) | |||
| start_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| yield IterationRunStartedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| metadata={ | |||
| "iterator_length": len(iterator_list_value) | |||
| }, | |||
| predecessor_node_id=self.previous_node_id | |||
| ) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=0, | |||
| pre_iteration_output=None | |||
| ) | |||
| outputs: list[Any] = [] | |||
| try: | |||
| # run workflow | |||
| rst = graph_engine.run() | |||
| for event in rst: | |||
| if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: | |||
| event.in_iteration_id = self.node_id | |||
| if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START: | |||
| continue | |||
| if isinstance(event, NodeRunSucceededEvent): | |||
| if event.route_node_state.node_run_result: | |||
| metadata = event.route_node_state.node_run_result.metadata | |||
| if not metadata: | |||
| metadata = {} | |||
| if NodeRunMetadataKey.ITERATION_ID not in metadata: | |||
| metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id | |||
| metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) | |||
| event.route_node_state.node_run_result.metadata = metadata | |||
| yield event | |||
| # handle iteration run result | |||
| if event.route_node_state.node_id in iteration_leaf_node_ids: | |||
| # append to iteration output variable list | |||
| current_iteration_output = variable_pool.get_any(self.node_data.output_selector) | |||
| outputs.append(current_iteration_output) | |||
| # remove all nodes outputs from variable pool | |||
| for node_id in iteration_graph.node_ids: | |||
| variable_pool.remove_node(node_id) | |||
| # move to next iteration | |||
| current_index = variable_pool.get([self.node_id, 'index']) | |||
| if current_index is None: | |||
| raise ValueError(f'iteration {self.node_id} current index not found') | |||
| next_index = int(current_index.to_object()) + 1 | |||
| variable_pool.add( | |||
| [self.node_id, 'index'], | |||
| next_index | |||
| ) | |||
| if next_index < len(iterator_list_value): | |||
| variable_pool.add( | |||
| [self.node_id, 'item'], | |||
| iterator_list_value[next_index] | |||
| ) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=next_index, | |||
| pre_iteration_output=jsonable_encoder( | |||
| current_iteration_output) if current_iteration_output else None | |||
| ) | |||
| elif isinstance(event, BaseGraphEvent): | |||
| if isinstance(event, GraphRunFailedEvent): | |||
| # iteration run failed | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={ | |||
| "output": jsonable_encoder(outputs) | |||
| }, | |||
| steps=len(iterator_list_value), | |||
| metadata={ | |||
| "total_tokens": graph_engine.graph_runtime_state.total_tokens | |||
| }, | |||
| error=event.error, | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| break | |||
| else: | |||
| event = cast(InNodeEvent, event) | |||
| yield event | |||
| yield IterationRunSucceededEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={ | |||
| 'output': jsonable_encoder(state.outputs) | |||
| "output": jsonable_encoder(outputs) | |||
| }, | |||
| steps=len(iterator_list_value), | |||
| metadata={ | |||
| "total_tokens": graph_engine.graph_runtime_state.total_tokens | |||
| } | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={ | |||
| 'output': jsonable_encoder(outputs) | |||
| } | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| # iteration run failed | |||
| logger.exception("Iteration run failed") | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={ | |||
| "output": jsonable_encoder(outputs) | |||
| }, | |||
| steps=len(iterator_list_value), | |||
| metadata={ | |||
| "total_tokens": graph_engine.graph_runtime_state.total_tokens | |||
| }, | |||
| error=str(e), | |||
| ) | |||
| return node_data.start_node_id | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| ) | |||
| ) | |||
| finally: | |||
| # remove iteration variable (item, index) from variable pool after iteration run completed | |||
| variable_pool.remove([self.node_id, 'index']) | |||
| variable_pool.remove([self.node_id, 'item']) | |||
| def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: IterationNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Set current iteration variable. | |||
| :variable_pool: variable pool | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| node_data = cast(IterationNodeData, self.node_data) | |||
| variable_mapping = { | |||
| f'{node_id}.input_selector': node_data.iterator_selector, | |||
| } | |||
| variable_pool.add((self.node_id, 'index'), state.index) | |||
| # get the iterator value | |||
| iterator = variable_pool.get_any(node_data.iterator_selector) | |||
| # init graph | |||
| iteration_graph = Graph.init( | |||
| graph_config=graph_config, | |||
| root_node_id=node_data.start_node_id | |||
| ) | |||
| if iterator is None or not isinstance(iterator, list): | |||
| return | |||
| if not iteration_graph: | |||
| raise ValueError('iteration graph not found') | |||
| if state.index < len(iterator): | |||
| variable_pool.add((self.node_id, 'item'), iterator[state.index]) | |||
| for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): | |||
| if sub_node_config.get('data', {}).get('iteration_id') != node_id: | |||
| continue | |||
| def _next_iteration(self, variable_pool: VariablePool, state: IterationState): | |||
| """ | |||
| Move to next iteration. | |||
| :param variable_pool: variable pool | |||
| """ | |||
| state.index += 1 | |||
| self._set_current_iteration_variable(variable_pool, state) | |||
| # variable selector to variable mapping | |||
| try: | |||
| # Get node class | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) | |||
| node_cls = node_classes.get(node_type) | |||
| if not node_cls: | |||
| continue | |||
| def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): | |||
| """ | |||
| Check if iteration limit is reached. | |||
| :return: True if iteration limit is reached, False otherwise | |||
| """ | |||
| node_data = cast(IterationNodeData, self.node_data) | |||
| iterator = variable_pool.get_any(node_data.iterator_selector) | |||
| node_cls = cast(BaseNode, node_cls) | |||
| sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||
| graph_config=graph_config, | |||
| config=sub_node_config | |||
| ) | |||
| sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) | |||
| except NotImplementedError: | |||
| sub_node_variable_mapping = {} | |||
| if iterator is None or not isinstance(iterator, list): | |||
| return True | |||
| # remove iteration variables | |||
| sub_node_variable_mapping = { | |||
| sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() | |||
| if value[0] != node_id | |||
| } | |||
| return state.index >= len(iterator) | |||
| def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): | |||
| """ | |||
| Resolve current output. | |||
| :param variable_pool: variable pool | |||
| """ | |||
| output_selector = cast(IterationNodeData, self.node_data).output_selector | |||
| output = variable_pool.get_any(output_selector) | |||
| # clear the output for this iteration | |||
| variable_pool.remove([self.node_id] + output_selector[1:]) | |||
| state.current_output = output | |||
| if output is not None: | |||
| # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). | |||
| if isinstance(output, list): | |||
| state.outputs.extend(output) | |||
| else: | |||
| state.outputs.append(output) | |||
| variable_mapping.update(sub_node_variable_mapping) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return { | |||
| 'input_selector': node_data.iterator_selector, | |||
| } | |||
| # remove variable out from iteration | |||
| variable_mapping = { | |||
| key: value for key, value in variable_mapping.items() | |||
| if value[0] not in iteration_graph.node_ids | |||
| } | |||
| return variable_mapping | |||
| @@ -0,0 +1,39 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class IterationStartNode(BaseNode): | |||
| """ | |||
| Iteration Start Node. | |||
| """ | |||
| _node_data_cls = IterationStartNodeData | |||
| _node_type = NodeType.ITERATION_START | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the node. | |||
| """ | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: IterationNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return {} | |||
| @@ -1,3 +1,5 @@ | |||
| import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from sqlalchemy import func | |||
| @@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| logger = logging.getLogger(__name__) | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| @@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| _node_data_cls = KnowledgeRetrievalNodeData | |||
| node_type = NodeType.KNOWLEDGE_RETRIEVAL | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) | |||
| def _run(self) -> NodeRunResult: | |||
| node_data = cast(KnowledgeRetrievalNodeData, self.node_data) | |||
| # extract variables | |||
| variable = variable_pool.get_any(node_data.query_variable_selector) | |||
| variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) | |||
| query = variable | |||
| variables = { | |||
| 'query': query | |||
| @@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| ) | |||
| except Exception as e: | |||
| logger.exception("Error when running knowledge retrieval node") | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=variables, | |||
| @@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| return context_list | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| node_data = node_data | |||
| node_data = cast(cls._node_data_cls, node_data) | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: KnowledgeRetrievalNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| variable_mapping = {} | |||
| variable_mapping['query'] = node_data.query_variable_selector | |||
| variable_mapping[node_id + '.query'] = node_data.query_variable_selector | |||
| return variable_mapping | |||
| def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ | |||
| @@ -1,16 +1,17 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from copy import deepcopy | |||
| from typing import TYPE_CHECKING, Optional, cast | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| from pydantic import BaseModel | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.entities.message_entities import ( | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| @@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import InNodeEvent | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.llm.entities import ( | |||
| LLMNodeChatModelMessage, | |||
| LLMNodeCompletionModelPromptTemplate, | |||
| @@ -43,17 +46,26 @@ if TYPE_CHECKING: | |||
| class ModelInvokeCompleted(BaseModel): | |||
| """ | |||
| Model invoke completed | |||
| """ | |||
| text: str | |||
| usage: LLMUsage | |||
| finish_reason: Optional[str] = None | |||
| class LLMNode(BaseNode): | |||
| _node_data_cls = LLMNodeData | |||
| _node_type = NodeType.LLM | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_data = cast(LLMNodeData, deepcopy(self.node_data)) | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| node_inputs = None | |||
| process_data = None | |||
| @@ -80,10 +92,15 @@ class LLMNode(BaseNode): | |||
| node_inputs['#files#'] = [file.to_dict() for file in files] | |||
| # fetch context value | |||
| context = self._fetch_context(node_data, variable_pool) | |||
| generator = self._fetch_context(node_data, variable_pool) | |||
| context = None | |||
| for event in generator: | |||
| if isinstance(event, RunRetrieverResourceEvent): | |||
| context = event.context | |||
| yield event | |||
| if context: | |||
| node_inputs['#context#'] = context | |||
| node_inputs['#context#'] = context # type: ignore | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data.model) | |||
| @@ -115,19 +132,34 @@ class LLMNode(BaseNode): | |||
| } | |||
| # handle invoke result | |||
| result_text, usage, finish_reason = self._invoke_llm( | |||
| generator = self._invoke_llm( | |||
| node_data_model=node_data.model, | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| stop=stop | |||
| ) | |||
| result_text = '' | |||
| usage = LLMUsage.empty_usage() | |||
| finish_reason = None | |||
| for event in generator: | |||
| if isinstance(event, RunStreamChunkEvent): | |||
| yield event | |||
| elif isinstance(event, ModelInvokeCompleted): | |||
| result_text = event.text | |||
| usage = event.usage | |||
| finish_reason = event.finish_reason | |||
| break | |||
| except Exception as e: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| inputs=node_inputs, | |||
| process_data=process_data | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| inputs=node_inputs, | |||
| process_data=process_data | |||
| ) | |||
| ) | |||
| return | |||
| outputs = { | |||
| 'text': result_text, | |||
| @@ -135,22 +167,26 @@ class LLMNode(BaseNode): | |||
| 'finish_reason': finish_reason | |||
| } | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=node_inputs, | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency | |||
| } | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=node_inputs, | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency | |||
| }, | |||
| llm_usage=usage | |||
| ) | |||
| ) | |||
| def _invoke_llm(self, node_data_model: ModelConfig, | |||
| model_instance: ModelInstance, | |||
| prompt_messages: list[PromptMessage], | |||
| stop: list[str]) -> tuple[str, LLMUsage]: | |||
| stop: Optional[list[str]] = None) \ | |||
| -> Generator[RunEvent | ModelInvokeCompleted, None, None]: | |||
| """ | |||
| Invoke large language model | |||
| :param node_data_model: node data model | |||
| @@ -170,23 +206,31 @@ class LLMNode(BaseNode): | |||
| ) | |||
| # handle invoke result | |||
| text, usage, finish_reason = self._handle_invoke_result( | |||
| generator = self._handle_invoke_result( | |||
| invoke_result=invoke_result | |||
| ) | |||
| usage = LLMUsage.empty_usage() | |||
| for event in generator: | |||
| yield event | |||
| if isinstance(event, ModelInvokeCompleted): | |||
| usage = event.usage | |||
| # deduct quota | |||
| self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) | |||
| return text, usage, finish_reason | |||
| def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: | |||
| def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ | |||
| -> Generator[RunEvent | ModelInvokeCompleted, None, None]: | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| :return: | |||
| """ | |||
| if isinstance(invoke_result, LLMResult): | |||
| return | |||
| model = None | |||
| prompt_messages = [] | |||
| prompt_messages: list[PromptMessage] = [] | |||
| full_text = '' | |||
| usage = None | |||
| finish_reason = None | |||
| @@ -194,7 +238,10 @@ class LLMNode(BaseNode): | |||
| text = result.delta.message.content | |||
| full_text += text | |||
| self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) | |||
| yield RunStreamChunkEvent( | |||
| chunk_content=text, | |||
| from_variable_selector=[self.node_id, 'text'] | |||
| ) | |||
| if not model: | |||
| model = result.model | |||
| @@ -211,11 +258,15 @@ class LLMNode(BaseNode): | |||
| if not usage: | |||
| usage = LLMUsage.empty_usage() | |||
| return full_text, usage, finish_reason | |||
| yield ModelInvokeCompleted( | |||
| text=full_text, | |||
| usage=usage, | |||
| finish_reason=finish_reason | |||
| ) | |||
| def _transform_chat_messages(self, | |||
| messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate | |||
| ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: | |||
| messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate | |||
| ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: | |||
| """ | |||
| Transform chat messages | |||
| @@ -224,13 +275,13 @@ class LLMNode(BaseNode): | |||
| """ | |||
| if isinstance(messages, LLMNodeCompletionModelPromptTemplate): | |||
| if messages.edition_type == 'jinja2': | |||
| if messages.edition_type == 'jinja2' and messages.jinja2_text: | |||
| messages.text = messages.jinja2_text | |||
| return messages | |||
| for message in messages: | |||
| if message.edition_type == 'jinja2': | |||
| if message.edition_type == 'jinja2' and message.jinja2_text: | |||
| message.text = message.jinja2_text | |||
| return messages | |||
| @@ -348,7 +399,7 @@ class LLMNode(BaseNode): | |||
| return files | |||
| def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: | |||
| def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: | |||
| """ | |||
| Fetch context | |||
| :param node_data: node data | |||
| @@ -356,15 +407,18 @@ class LLMNode(BaseNode): | |||
| :return: | |||
| """ | |||
| if not node_data.context.enabled: | |||
| return None | |||
| return | |||
| if not node_data.context.variable_selector: | |||
| return None | |||
| return | |||
| context_value = variable_pool.get_any(node_data.context.variable_selector) | |||
| if context_value: | |||
| if isinstance(context_value, str): | |||
| return context_value | |||
| yield RunRetrieverResourceEvent( | |||
| retriever_resources=[], | |||
| context=context_value | |||
| ) | |||
| elif isinstance(context_value, list): | |||
| context_str = '' | |||
| original_retriever_resource = [] | |||
| @@ -381,17 +435,10 @@ class LLMNode(BaseNode): | |||
| if retriever_resource: | |||
| original_retriever_resource.append(retriever_resource) | |||
| if self.callbacks and original_retriever_resource: | |||
| for callback in self.callbacks: | |||
| callback.on_event( | |||
| event=QueueRetrieverResourcesEvent( | |||
| retriever_resources=original_retriever_resource | |||
| ) | |||
| ) | |||
| return context_str.strip() | |||
| return None | |||
| yield RunRetrieverResourceEvent( | |||
| retriever_resources=original_retriever_resource, | |||
| context=context_str.strip() | |||
| ) | |||
| def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: | |||
| """ | |||
| @@ -574,7 +621,8 @@ class LLMNode(BaseNode): | |||
| if not isinstance(prompt_message.content, str): | |||
| prompt_message_content = [] | |||
| for content_item in prompt_message.content: | |||
| if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent): | |||
| if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( | |||
| content_item, ImagePromptMessageContent): | |||
| # Override vision config if LLM node has vision config | |||
| if vision_detail: | |||
| content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) | |||
| @@ -646,13 +694,19 @@ class LLMNode(BaseNode): | |||
| db.session.commit() | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: LLMNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| prompt_template = node_data.prompt_template | |||
| variable_selectors = [] | |||
| @@ -702,6 +756,10 @@ class LLMNode(BaseNode): | |||
| for variable_selector in node_data.prompt_config.jinja2_variables or []: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| variable_mapping = { | |||
| node_id + '.' + key: value for key, value in variable_mapping.items() | |||
| } | |||
| return variable_mapping | |||
| @classmethod | |||
| @@ -1,20 +1,34 @@ | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseIterationNode | |||
| from typing import Any | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.loop.entities import LoopNodeData, LoopState | |||
| from core.workflow.utils.condition.entities import Condition | |||
| class LoopNode(BaseIterationNode): | |||
| class LoopNode(BaseNode): | |||
| """ | |||
| Loop Node. | |||
| """ | |||
| _node_data_cls = LoopNodeData | |||
| _node_type = NodeType.LOOP | |||
| def _run(self, variable_pool: VariablePool) -> LoopState: | |||
| return super()._run(variable_pool) | |||
| def _run(self) -> LoopState: | |||
| return super()._run() | |||
| def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: | |||
| @classmethod | |||
| def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: | |||
| """ | |||
| Get next iteration start node id based on the graph. | |||
| Get conditions. | |||
| """ | |||
| node_id = node_config.get('id') | |||
| if not node_id: | |||
| return [] | |||
| # TODO waiting for implementation | |||
| return [Condition( | |||
| variable_selector=[node_id, 'index'], | |||
| comparison_operator="≤", | |||
| value_type="value_selector", | |||
| value_selector=[] | |||
| )] | |||
| @@ -0,0 +1,37 @@ | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from core.workflow.nodes.code.code_node import CodeNode | |||
| from core.workflow.nodes.end.end_node import EndNode | |||
| from core.workflow.nodes.http_request.http_request_node import HttpRequestNode | |||
| from core.workflow.nodes.if_else.if_else_node import IfElseNode | |||
| from core.workflow.nodes.iteration.iteration_node import IterationNode | |||
| from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode | |||
| from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode | |||
| from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode | |||
| from core.workflow.nodes.start.start_node import StartNode | |||
| from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode | |||
| from core.workflow.nodes.tool.tool_node import ToolNode | |||
| from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode | |||
| from core.workflow.nodes.variable_assigner import VariableAssignerNode | |||
| node_classes = { | |||
| NodeType.START: StartNode, | |||
| NodeType.END: EndNode, | |||
| NodeType.ANSWER: AnswerNode, | |||
| NodeType.LLM: LLMNode, | |||
| NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, | |||
| NodeType.IF_ELSE: IfElseNode, | |||
| NodeType.CODE: CodeNode, | |||
| NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, | |||
| NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, | |||
| NodeType.HTTP_REQUEST: HttpRequestNode, | |||
| NodeType.TOOL: ToolNode, | |||
| NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, | |||
| NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR | |||
| NodeType.ITERATION: IterationNode, | |||
| NodeType.ITERATION_START: IterationStartNode, | |||
| NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, | |||
| NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| import json | |||
| import uuid | |||
| from typing import Optional, cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| @@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode): | |||
| } | |||
| } | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the node. | |||
| """ | |||
| node_data = cast(ParameterExtractorNodeData, self.node_data) | |||
| variable = variable_pool.get_any(node_data.query) | |||
| variable = self.graph_runtime_state.variable_pool.get_any(node_data.query) | |||
| if not variable: | |||
| raise ValueError("Input variable content not found or is empty") | |||
| query = variable | |||
| @@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode): | |||
| raise ValueError("Model schema not found") | |||
| # fetch memory | |||
| memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | |||
| memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) | |||
| if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ | |||
| and node_data.reasoning_mode == 'function_call': | |||
| # use function call | |||
| prompt_messages, prompt_message_tools = self._generate_function_call_prompt( | |||
| node_data, query, variable_pool, model_config, memory | |||
| node_data, query, self.graph_runtime_state.variable_pool, model_config, memory | |||
| ) | |||
| else: | |||
| # use prompt engineering | |||
| prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, | |||
| prompt_messages = self._generate_prompt_engineering_prompt(node_data, | |||
| query, | |||
| self.graph_runtime_state.variable_pool, | |||
| model_config, | |||
| memory) | |||
| prompt_message_tools = [] | |||
| @@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode): | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency | |||
| } | |||
| }, | |||
| llm_usage=usage | |||
| ) | |||
| def _invoke_llm(self, node_data_model: ModelConfig, | |||
| @@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode): | |||
| return self._model_instance, self._model_config | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ | |||
| str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: ParameterExtractorNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| node_data = node_data | |||
| variable_mapping = { | |||
| 'query': node_data.query | |||
| } | |||
| @@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode): | |||
| for selector in variable_template_parser.extract_variable_selectors(): | |||
| variable_mapping[selector.variable] = selector.value_selector | |||
| variable_mapping = { | |||
| node_id + '.' + key: value for key, value in variable_mapping.items() | |||
| } | |||
| return variable_mapping | |||
| @@ -1,10 +1,12 @@ | |||
| import json | |||
| import logging | |||
| from typing import Optional, Union, cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| @@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted | |||
| from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData | |||
| from core.workflow.nodes.question_classifier.template_prompts import ( | |||
| QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, | |||
| @@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode): | |||
| _node_data_cls = QuestionClassifierNodeData | |||
| node_type = NodeType.QUESTION_CLASSIFIER | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) | |||
| node_data = cast(QuestionClassifierNodeData, node_data) | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| # extract variables | |||
| variable = variable_pool.get(node_data.query_variable_selector) | |||
| @@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode): | |||
| ) | |||
| # handle invoke result | |||
| result_text, usage, finish_reason = self._invoke_llm( | |||
| generator = self._invoke_llm( | |||
| node_data_model=node_data.model, | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| stop=stop | |||
| ) | |||
| result_text = '' | |||
| usage = LLMUsage.empty_usage() | |||
| finish_reason = None | |||
| for event in generator: | |||
| if isinstance(event, ModelInvokeCompleted): | |||
| result_text = event.text | |||
| usage = event.usage | |||
| finish_reason = event.finish_reason | |||
| break | |||
| category_name = node_data.classes[0].name | |||
| category_id = node_data.classes[0].id | |||
| try: | |||
| @@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode): | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency | |||
| } | |||
| }, | |||
| llm_usage=usage | |||
| ) | |||
| except ValueError as e: | |||
| @@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode): | |||
| NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, | |||
| NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, | |||
| NodeRunMetadataKey.CURRENCY: usage.currency | |||
| } | |||
| }, | |||
| llm_usage=usage | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| node_data = node_data | |||
| node_data = cast(cls._node_data_cls, node_data) | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: QuestionClassifierNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| variable_mapping = {'query': node_data.query_variable_selector} | |||
| variable_selectors = [] | |||
| if node_data.instruction: | |||
| @@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode): | |||
| variable_selectors.extend(variable_template_parser.extract_variable_selectors()) | |||
| for variable_selector in variable_selectors: | |||
| variable_mapping[variable_selector.variable] = variable_selector.value_selector | |||
| variable_mapping = { | |||
| node_id + '.' + key: value for key, value in variable_mapping.items() | |||
| } | |||
| return variable_mapping | |||
| @classmethod | |||
| @@ -1,7 +1,9 @@ | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool | |||
| from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.start.entities import StartNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -11,14 +13,13 @@ class StartNode(BaseNode): | |||
| _node_data_cls = StartNodeData | |||
| _node_type = NodeType.START | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| :param variable_pool: variable pool | |||
| :return: | |||
| """ | |||
| node_inputs = dict(variable_pool.user_inputs) | |||
| system_inputs = variable_pool.system_variables | |||
| node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) | |||
| system_inputs = self.graph_runtime_state.variable_pool.system_variables | |||
| for var in system_inputs: | |||
| node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] | |||
| @@ -30,9 +31,16 @@ class StartNode(BaseNode): | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: StartNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -1,15 +1,16 @@ | |||
| import os | |||
| from typing import Optional, cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) | |||
| class TemplateTransformNode(BaseNode): | |||
| _node_data_cls = TemplateTransformNodeData | |||
| _node_type = NodeType.TEMPLATE_TRANSFORM | |||
| @@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode): | |||
| } | |||
| } | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| """ | |||
| @@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode): | |||
| variables = {} | |||
| for variable_selector in node_data.variables: | |||
| variable_name = variable_selector.variable | |||
| value = variable_pool.get_any(variable_selector.value_selector) | |||
| value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) | |||
| variables[variable_name] = value | |||
| # Run code | |||
| try: | |||
| @@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode): | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e) | |||
| ) | |||
| if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: | |||
| return NodeRunResult( | |||
| inputs=variables, | |||
| @@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode): | |||
| 'output': result['result'] | |||
| } | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: TemplateTransformNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return { | |||
| variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables | |||
| } | |||
| node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables | |||
| } | |||
| @@ -26,7 +26,7 @@ class ToolNode(BaseNode): | |||
| _node_data_cls = ToolNodeData | |||
| _node_type = NodeType.TOOL | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the tool node | |||
| """ | |||
| @@ -56,8 +56,8 @@ class ToolNode(BaseNode): | |||
| # get parameters | |||
| tool_parameters = tool_runtime.get_runtime_parameters() or [] | |||
| parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data) | |||
| parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True) | |||
| parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) | |||
| parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) | |||
| try: | |||
| messages = ToolEngine.workflow_invoke( | |||
| @@ -66,6 +66,7 @@ class ToolNode(BaseNode): | |||
| user_id=self.user_id, | |||
| workflow_tool_callback=DifyWorkflowCallbackHandler(), | |||
| workflow_call_depth=self.workflow_call_depth, | |||
| thread_pool_id=self.thread_pool_id, | |||
| ) | |||
| except Exception as e: | |||
| return NodeRunResult( | |||
| @@ -145,7 +146,8 @@ class ToolNode(BaseNode): | |||
| assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) | |||
| return list(variable.value) if variable else [] | |||
| def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): | |||
| def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ | |||
| -> tuple[str, list[FileVar], list[dict]]: | |||
| """ | |||
| Convert ToolInvokeMessages into tuple[plain_text, files] | |||
| """ | |||
| @@ -221,9 +223,16 @@ class ToolNode(BaseNode): | |||
| return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: ToolNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| @@ -239,4 +248,8 @@ class ToolNode(BaseNode): | |||
| elif input.type == 'constant': | |||
| pass | |||
| result = { | |||
| node_id + '.' + key: value for key, value in result.items() | |||
| } | |||
| return result | |||
| @@ -1,8 +1,7 @@ | |||
| from typing import cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode): | |||
| _node_data_cls = VariableAssignerNodeData | |||
| _node_type = NodeType.VARIABLE_AGGREGATOR | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| node_data = cast(VariableAssignerNodeData, self.node_data) | |||
| # Get variables | |||
| outputs = {} | |||
| @@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode): | |||
| if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: | |||
| for selector in node_data.variables: | |||
| variable = variable_pool.get_any(selector) | |||
| variable = self.graph_runtime_state.variable_pool.get_any(selector) | |||
| if variable is not None: | |||
| outputs = { | |||
| "output": variable | |||
| @@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode): | |||
| else: | |||
| for group in node_data.advanced_settings.groups: | |||
| for selector in group.variables: | |||
| variable = variable_pool.get_any(selector) | |||
| variable = self.graph_runtime_state.variable_pool.get_any(selector) | |||
| if variable is not None: | |||
| outputs[group.group_name] = { | |||
| @@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode): | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: VariableAssignerNodeData | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return {} | |||
| @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session | |||
| from core.app.segments import SegmentType, Variable, factory | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from extensions.ext_database import db | |||
| from models import ConversationVariable, WorkflowNodeExecutionStatus | |||
| @@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode): | |||
| _node_data_cls: type[BaseNodeData] = VariableAssignerData | |||
| _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER | |||
| def _run(self, variable_pool: VariablePool) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: | |||
| data = cast(VariableAssignerData, self.node_data) | |||
| # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject | |||
| original_variable = variable_pool.get(data.assigned_variable_selector) | |||
| original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) | |||
| if not isinstance(original_variable, Variable): | |||
| raise VariableAssignerNodeError('assigned variable not found') | |||
| match data.write_mode: | |||
| case WriteMode.OVER_WRITE: | |||
| income_value = variable_pool.get(data.input_variable_selector) | |||
| income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) | |||
| if not income_value: | |||
| raise VariableAssignerNodeError('input value not found') | |||
| updated_variable = original_variable.model_copy(update={'value': income_value.value}) | |||
| case WriteMode.APPEND: | |||
| income_value = variable_pool.get(data.input_variable_selector) | |||
| income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) | |||
| if not income_value: | |||
| raise VariableAssignerNodeError('input value not found') | |||
| updated_value = original_variable.value + [income_value.value] | |||
| @@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode): | |||
| raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') | |||
| # Over write the variable. | |||
| variable_pool.add(data.assigned_variable_selector, updated_variable) | |||
| self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) | |||
| # TODO: Move database operation to the pipeline. | |||
| # Update conversation variable. | |||
| conversation_id = variable_pool.get(['sys', 'conversation_id']) | |||
| conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) | |||
| if not conversation_id: | |||
| raise VariableAssignerNodeError('conversation_id not found') | |||
| update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) | |||
| @@ -0,0 +1,17 @@ | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel | |||
| class Condition(BaseModel): | |||
| """ | |||
| Condition entity | |||
| """ | |||
| variable_selector: list[str] | |||
| comparison_operator: Literal[ | |||
| # for string or array | |||
| "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", | |||
| # for number | |||
| "=", "≠", ">", "<", "≥", "≤", "null", "not null" | |||
| ] | |||
| value: Optional[str] = None | |||
| @@ -0,0 +1,383 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional | |||
| from core.file.file_obj import FileVar | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.utils.condition.entities import Condition | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| class ConditionProcessor: | |||
| def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): | |||
| input_conditions = [] | |||
| group_result = [] | |||
| index = 0 | |||
| for condition in conditions: | |||
| index += 1 | |||
| actual_value = variable_pool.get_any( | |||
| condition.variable_selector | |||
| ) | |||
| expected_value = None | |||
| if condition.value is not None: | |||
| variable_template_parser = VariableTemplateParser(template=condition.value) | |||
| variable_selectors = variable_template_parser.extract_variable_selectors() | |||
| if variable_selectors: | |||
| for variable_selector in variable_selectors: | |||
| value = variable_pool.get_any( | |||
| variable_selector.value_selector | |||
| ) | |||
| expected_value = variable_template_parser.format({variable_selector.variable: value}) | |||
| if expected_value is None: | |||
| expected_value = condition.value | |||
| else: | |||
| expected_value = condition.value | |||
| comparison_operator = condition.comparison_operator | |||
| input_conditions.append( | |||
| { | |||
| "actual_value": actual_value, | |||
| "expected_value": expected_value, | |||
| "comparison_operator": comparison_operator | |||
| } | |||
| ) | |||
| result = self.evaluate_condition(actual_value, comparison_operator, expected_value) | |||
| group_result.append(result) | |||
| return input_conditions, group_result | |||
| def evaluate_condition( | |||
| self, | |||
| actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], | |||
| comparison_operator: str, | |||
| expected_value: Optional[str] = None | |||
| ) -> bool: | |||
| """ | |||
| Evaluate condition | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :param comparison_operator: comparison operator | |||
| :return: bool | |||
| """ | |||
| if comparison_operator == "contains": | |||
| return self._assert_contains(actual_value, expected_value) | |||
| elif comparison_operator == "not contains": | |||
| return self._assert_not_contains(actual_value, expected_value) | |||
| elif comparison_operator == "start with": | |||
| return self._assert_start_with(actual_value, expected_value) | |||
| elif comparison_operator == "end with": | |||
| return self._assert_end_with(actual_value, expected_value) | |||
| elif comparison_operator == "is": | |||
| return self._assert_is(actual_value, expected_value) | |||
| elif comparison_operator == "is not": | |||
| return self._assert_is_not(actual_value, expected_value) | |||
| elif comparison_operator == "empty": | |||
| return self._assert_empty(actual_value) | |||
| elif comparison_operator == "not empty": | |||
| return self._assert_not_empty(actual_value) | |||
| elif comparison_operator == "=": | |||
| return self._assert_equal(actual_value, expected_value) | |||
| elif comparison_operator == "≠": | |||
| return self._assert_not_equal(actual_value, expected_value) | |||
| elif comparison_operator == ">": | |||
| return self._assert_greater_than(actual_value, expected_value) | |||
| elif comparison_operator == "<": | |||
| return self._assert_less_than(actual_value, expected_value) | |||
| elif comparison_operator == "≥": | |||
| return self._assert_greater_than_or_equal(actual_value, expected_value) | |||
| elif comparison_operator == "≤": | |||
| return self._assert_less_than_or_equal(actual_value, expected_value) | |||
| elif comparison_operator == "null": | |||
| return self._assert_null(actual_value) | |||
| elif comparison_operator == "not null": | |||
| return self._assert_not_null(actual_value) | |||
| else: | |||
| raise ValueError(f"Invalid comparison operator: {comparison_operator}") | |||
| def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: | |||
| """ | |||
| Assert contains | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str | list): | |||
| raise ValueError('Invalid actual value type: string or array') | |||
| if expected_value not in actual_value: | |||
| return False | |||
| return True | |||
| def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: | |||
| """ | |||
| Assert not contains | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return True | |||
| if not isinstance(actual_value, str | list): | |||
| raise ValueError('Invalid actual value type: string or array') | |||
| if expected_value in actual_value: | |||
| return False | |||
| return True | |||
| def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert start with | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if not actual_value.startswith(expected_value): | |||
| return False | |||
| return True | |||
| def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert end with | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if not actual_value.endswith(expected_value): | |||
| return False | |||
| return True | |||
| def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert is | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if actual_value != expected_value: | |||
| return False | |||
| return True | |||
| def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: | |||
| """ | |||
| Assert is not | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, str): | |||
| raise ValueError('Invalid actual value type: string') | |||
| if actual_value == expected_value: | |||
| return False | |||
| return True | |||
| def _assert_empty(self, actual_value: Optional[str]) -> bool: | |||
| """ | |||
| Assert empty | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if not actual_value: | |||
| return True | |||
| return False | |||
| def _assert_not_empty(self, actual_value: Optional[str]) -> bool: | |||
| """ | |||
| Assert not empty | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value: | |||
| return True | |||
| return False | |||
| def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value != expected_value: | |||
| return False | |||
| return True | |||
| def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert not equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value == expected_value: | |||
| return False | |||
| return True | |||
| def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert greater than | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value <= expected_value: | |||
| return False | |||
| return True | |||
| def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert less than | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value >= expected_value: | |||
| return False | |||
| return True | |||
| def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], | |||
| expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert greater than or equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value < expected_value: | |||
| return False | |||
| return True | |||
| def _assert_less_than_or_equal(self, actual_value: Optional[int | float], | |||
| expected_value: str | int | float) -> bool: | |||
| """ | |||
| Assert less than or equal | |||
| :param actual_value: actual value | |||
| :param expected_value: expected value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return False | |||
| if not isinstance(actual_value, int | float): | |||
| raise ValueError('Invalid actual value type: number') | |||
| if isinstance(actual_value, int): | |||
| expected_value = int(expected_value) | |||
| else: | |||
| expected_value = float(expected_value) | |||
| if actual_value > expected_value: | |||
| return False | |||
| return True | |||
| def _assert_null(self, actual_value: Optional[int | float]) -> bool: | |||
| """ | |||
| Assert null | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value is None: | |||
| return True | |||
| return False | |||
| def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: | |||
| """ | |||
| Assert not null | |||
| :param actual_value: actual value | |||
| :return: | |||
| """ | |||
| if actual_value is not None: | |||
| return True | |||
| return False | |||
| class ConditionAssertionError(Exception): | |||
| def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: | |||
| self.message = message | |||
| self.conditions = conditions | |||
| self.sub_condition_compare_results = sub_condition_compare_results | |||
| super().__init__(self.message) | |||
| @@ -0,0 +1,314 @@ | |||
| import logging | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from configs import dify_config | |||
| from core.app.app_config.entities import FileExtraConfig | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.file_obj import FileTransferMethod, FileType, FileVar | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeType, UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||
| from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.event import RunEvent | |||
| from core.workflow.nodes.llm.entities import LLMNodeData | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowType, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class WorkflowEntry: | |||
| def __init__( | |||
| self, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| workflow_id: str, | |||
| workflow_type: WorkflowType, | |||
| graph_config: Mapping[str, Any], | |||
| graph: Graph, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| call_depth: int, | |||
| variable_pool: VariablePool, | |||
| thread_pool_id: Optional[str] = None | |||
| ) -> None: | |||
| """ | |||
| Init workflow entry | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param workflow_id: workflow id | |||
| :param workflow_type: workflow type | |||
| :param graph_config: workflow graph config | |||
| :param graph: workflow graph | |||
| :param user_id: user id | |||
| :param user_from: user from | |||
| :param invoke_from: invoke from | |||
| :param call_depth: call depth | |||
| :param variable_pool: variable pool | |||
| :param thread_pool_id: thread pool id | |||
| """ | |||
| # check call depth | |||
| workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH | |||
| if call_depth > workflow_call_max_depth: | |||
| raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) | |||
| # init workflow run state | |||
| self.graph_engine = GraphEngine( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_type=workflow_type, | |||
| workflow_id=workflow_id, | |||
| user_id=user_id, | |||
| user_from=user_from, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| graph=graph, | |||
| graph_config=graph_config, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, | |||
| max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, | |||
| thread_pool_id=thread_pool_id | |||
| ) | |||
| def run( | |||
| self, | |||
| *, | |||
| callbacks: Sequence[WorkflowCallback], | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| :param callbacks: workflow callbacks | |||
| """ | |||
| graph_engine = self.graph_engine | |||
| try: | |||
| # run workflow | |||
| generator = graph_engine.run() | |||
| for event in generator: | |||
| if callbacks: | |||
| for callback in callbacks: | |||
| callback.on_event( | |||
| event=event | |||
| ) | |||
| yield event | |||
| except GenerateTaskStoppedException: | |||
| pass | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when workflow entry running") | |||
| if callbacks: | |||
| for callback in callbacks: | |||
| callback.on_event( | |||
| event=GraphRunFailedEvent( | |||
| error=str(e) | |||
| ) | |||
| ) | |||
| return | |||
| @classmethod | |||
| def single_step_run( | |||
| cls, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_id: str, | |||
| user_inputs: dict | |||
| ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: | |||
| """ | |||
| Single step run workflow node | |||
| :param workflow: Workflow instance | |||
| :param node_id: node id | |||
| :param user_id: user id | |||
| :param user_inputs: user inputs | |||
| :return: | |||
| """ | |||
| # fetch node info from workflow graph | |||
| graph = workflow.graph_dict | |||
| if not graph: | |||
| raise ValueError('workflow graph not found') | |||
| nodes = graph.get('nodes') | |||
| if not nodes: | |||
| raise ValueError('nodes not found in workflow graph') | |||
| # fetch node config from node id | |||
| node_config = None | |||
| for node in nodes: | |||
| if node.get('id') == node_id: | |||
| node_config = node | |||
| break | |||
| if not node_config: | |||
| raise ValueError('node id not found in workflow graph') | |||
| # Get node class | |||
| node_type = NodeType.value_of(node_config.get('data', {}).get('type')) | |||
| node_cls = node_classes.get(node_type) | |||
| node_cls = cast(type[BaseNode], node_cls) | |||
| if not node_cls: | |||
| raise ValueError(f'Node class not found for node type {node_type}') | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| user_inputs={}, | |||
| environment_variables=workflow.environment_variables, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init( | |||
| graph_config=workflow.graph_dict | |||
| ) | |||
| # init workflow run state | |||
| node_instance: BaseNode = node_cls( | |||
| id=str(uuid.uuid4()), | |||
| config=node_config, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| workflow_id=workflow.id, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=user_id, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0 | |||
| ), | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=time.perf_counter() | |||
| ) | |||
| ) | |||
| try: | |||
| # variable selector to variable mapping | |||
| try: | |||
| variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( | |||
| graph_config=workflow.graph_dict, | |||
| config=node_config | |||
| ) | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| cls.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=workflow.tenant_id, | |||
| node_type=node_type, | |||
| node_data=node_instance.node_data | |||
| ) | |||
| # run node | |||
| generator = node_instance.run() | |||
| return node_instance, generator | |||
| except Exception as e: | |||
| raise WorkflowNodeRunFailedError( | |||
| node_instance=node_instance, | |||
| error=str(e) | |||
| ) | |||
| @classmethod | |||
| def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: | |||
| """ | |||
| Handle special values | |||
| :param value: value | |||
| :return: | |||
| """ | |||
| if not value: | |||
| return None | |||
| new_value = dict(value) if value else {} | |||
| if isinstance(new_value, dict): | |||
| for key, val in new_value.items(): | |||
| if isinstance(val, FileVar): | |||
| new_value[key] = val.to_dict() | |||
| elif isinstance(val, list): | |||
| new_val = [] | |||
| for v in val: | |||
| if isinstance(v, FileVar): | |||
| new_val.append(v.to_dict()) | |||
| else: | |||
| new_val.append(v) | |||
| new_value[key] = new_val | |||
| return new_value | |||
| @classmethod | |||
| def mapping_user_inputs_to_variable_pool( | |||
| cls, | |||
| variable_mapping: Mapping[str, Sequence[str]], | |||
| user_inputs: dict, | |||
| variable_pool: VariablePool, | |||
| tenant_id: str, | |||
| node_type: NodeType, | |||
| node_data: BaseNodeData | |||
| ) -> None: | |||
| for node_variable, variable_selector in variable_mapping.items(): | |||
| # fetch node id and variable key from node_variable | |||
| node_variable_list = node_variable.split('.') | |||
| if len(node_variable_list) < 1: | |||
| raise ValueError(f'Invalid node variable {node_variable}') | |||
| node_variable_key = '.'.join(node_variable_list[1:]) | |||
| if ( | |||
| node_variable_key not in user_inputs | |||
| and node_variable not in user_inputs | |||
| ) and not variable_pool.get(variable_selector): | |||
| raise ValueError(f'Variable key {node_variable} not found in user inputs.') | |||
| # fetch variable node id from variable selector | |||
| variable_node_id = variable_selector[0] | |||
| variable_key_list = variable_selector[1:] | |||
| variable_key_list = cast(list[str], variable_key_list) | |||
| # get input value | |||
| input_value = user_inputs.get(node_variable) | |||
| if not input_value: | |||
| input_value = user_inputs.get(node_variable_key) | |||
| # FIXME: temp fix for image type | |||
| if node_type == NodeType.LLM: | |||
| new_value = [] | |||
| if isinstance(input_value, list): | |||
| node_data = cast(LLMNodeData, node_data) | |||
| detail = node_data.vision.configs.detail if node_data.vision.configs else None | |||
| for item in input_value: | |||
| if isinstance(item, dict) and 'type' in item and item['type'] == 'image': | |||
| transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) | |||
| file = FileVar( | |||
| tenant_id=tenant_id, | |||
| type=FileType.IMAGE, | |||
| transfer_method=transfer_method, | |||
| url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, | |||
| related_id=item.get( | |||
| 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, | |||
| extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), | |||
| ) | |||
| new_value.append(file) | |||
| if new_value: | |||
| value = new_value | |||
| # append variable and value to variable pool | |||
| variable_pool.add([variable_node_id] + variable_key_list, input_value) | |||
| @@ -0,0 +1,35 @@ | |||
| """add node_execution_id into node_executions | |||
| Revision ID: 675b5321501b | |||
| Revises: 030f4915f36a | |||
| Create Date: 2024-08-12 10:54:02.259331 | |||
| """ | |||
| import sqlalchemy as sa | |||
| from alembic import op | |||
| import models as models | |||
| # revision identifiers, used by Alembic. | |||
| revision = '675b5321501b' | |||
| down_revision = '030f4915f36a' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) | |||
| batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.drop_index('workflow_node_execution_id_idx') | |||
| batch_op.drop_column('node_execution_id') | |||
| # ### end Alembic commands ### | |||
| @@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model): | |||
| 'triggered_from', 'workflow_run_id'), | |||
| db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', | |||
| 'triggered_from', 'node_id'), | |||
| db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', | |||
| 'triggered_from', 'node_execution_id'), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| @@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model): | |||
| workflow_run_id = db.Column(StringUUID) | |||
| index = db.Column(db.Integer, nullable=False) | |||
| predecessor_node_id = db.Column(db.String(255)) | |||
| node_execution_id = db.Column(db.String(255), nullable=True) | |||
| node_id = db.Column(db.String(255), nullable=False) | |||
| node_type = db.Column(db.String(255), nullable=False) | |||
| title = db.Column(db.String(255), nullable=False) | |||
| @@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| current_dsl_version = "0.1.1" | |||
| current_dsl_version = "0.1.2" | |||
| dsl_to_dify_version_mapping: dict[str, str] = { | |||
| "0.1.2": "0.8.0", | |||
| "0.1.1": "0.6.0", # dsl version -> from dify version | |||
| } | |||
| @@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.features.rate_limiting import RateLimit | |||
| from models.model import Account, App, AppMode, EndUser | |||
| from models.workflow import Workflow | |||
| from services.errors.llm import InvokeRateLimitError | |||
| from services.workflow_service import WorkflowService | |||
| @@ -103,9 +104,7 @@ class AppGenerateService: | |||
| return max_active_requests | |||
| @classmethod | |||
| def generate_single_iteration( | |||
| cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True | |||
| ): | |||
| def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): | |||
| if app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) | |||
| return AdvancedChatAppGenerator().single_iteration_generate( | |||
| @@ -142,7 +141,7 @@ class AppGenerateService: | |||
| ) | |||
| @classmethod | |||
| def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: | |||
| def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: | |||
| """ | |||
| Get workflow | |||
| :param app_model: app model | |||
| @@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.app.segments import Variable | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.entities.node_entities import NodeRunResult, NodeType | |||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| @@ -172,8 +174,13 @@ class WorkflowService: | |||
| Get default block configs | |||
| """ | |||
| # return default block config | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| return workflow_engine_manager.get_default_configs() | |||
| default_block_configs = [] | |||
| for node_type, node_class in node_classes.items(): | |||
| default_config = node_class.get_default_config() | |||
| if default_config: | |||
| default_block_configs.append(default_config) | |||
| return default_block_configs | |||
| def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: | |||
| """ | |||
| @@ -182,11 +189,18 @@ class WorkflowService: | |||
| :param filters: filter by node config parameters. | |||
| :return: | |||
| """ | |||
| node_type = NodeType.value_of(node_type) | |||
| node_type_enum: NodeType = NodeType.value_of(node_type) | |||
| # return default block config | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| return workflow_engine_manager.get_default_config(node_type, filters) | |||
| node_class = node_classes.get(node_type_enum) | |||
| if not node_class: | |||
| return None | |||
| default_config = node_class.get_default_config(filters=filters) | |||
| if not default_config: | |||
| return None | |||
| return default_config | |||
| def run_draft_workflow_node( | |||
| self, app_model: App, node_id: str, user_inputs: dict, account: Account | |||
| @@ -200,82 +214,68 @@ class WorkflowService: | |||
| raise ValueError("Workflow not initialized") | |||
| # run draft workflow node | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| start_at = time.perf_counter() | |||
| try: | |||
| node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( | |||
| node_instance, generator = WorkflowEntry.single_step_run( | |||
| workflow=draft_workflow, | |||
| node_id=node_id, | |||
| user_inputs=user_inputs, | |||
| user_id=account.id, | |||
| ) | |||
| except WorkflowNodeRunFailedError as e: | |||
| workflow_node_execution = WorkflowNodeExecution( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| workflow_id=draft_workflow.id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, | |||
| index=1, | |||
| node_id=e.node_id, | |||
| node_type=e.node_type.value, | |||
| title=e.node_title, | |||
| status=WorkflowNodeExecutionStatus.FAILED.value, | |||
| error=e.error, | |||
| elapsed_time=time.perf_counter() - start_at, | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| return workflow_node_execution | |||
| node_run_result: NodeRunResult | None = None | |||
| for event in generator: | |||
| if isinstance(event, RunCompletedEvent): | |||
| node_run_result = event.run_result | |||
| # sign output files | |||
| node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) | |||
| break | |||
| if not node_run_result: | |||
| raise ValueError("Node run failed with no run result") | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False | |||
| error = node_run_result.error if not run_succeeded else None | |||
| except WorkflowNodeRunFailedError as e: | |||
| node_instance = e.node_instance | |||
| run_succeeded = False | |||
| node_run_result = None | |||
| error = e.error | |||
| workflow_node_execution = WorkflowNodeExecution() | |||
| workflow_node_execution.tenant_id = app_model.tenant_id | |||
| workflow_node_execution.app_id = app_model.id | |||
| workflow_node_execution.workflow_id = draft_workflow.id | |||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value | |||
| workflow_node_execution.index = 1 | |||
| workflow_node_execution.node_id = node_id | |||
| workflow_node_execution.node_type = node_instance.node_type.value | |||
| workflow_node_execution.title = node_instance.node_data.title | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - start_at | |||
| workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value | |||
| workflow_node_execution.created_by = account.id | |||
| workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| if run_succeeded and node_run_result: | |||
| # create workflow node execution | |||
| workflow_node_execution = WorkflowNodeExecution( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| workflow_id=draft_workflow.id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, | |||
| index=1, | |||
| node_id=node_id, | |||
| node_type=node_instance.node_type.value, | |||
| title=node_instance.node_data.title, | |||
| inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, | |||
| process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, | |||
| outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, | |||
| execution_metadata=( | |||
| json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None | |||
| ), | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED.value, | |||
| elapsed_time=time.perf_counter() - start_at, | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None | |||
| workflow_node_execution.process_data = ( | |||
| json.dumps(node_run_result.process_data) if node_run_result.process_data else None | |||
| ) | |||
| workflow_node_execution.outputs = ( | |||
| json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None | |||
| ) | |||
| workflow_node_execution.execution_metadata = ( | |||
| json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None | |||
| ) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| else: | |||
| # create workflow node execution | |||
| workflow_node_execution = WorkflowNodeExecution( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| workflow_id=draft_workflow.id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, | |||
| index=1, | |||
| node_id=node_id, | |||
| node_type=node_instance.node_type.value, | |||
| title=node_instance.node_data.title, | |||
| status=node_run_result.status.value, | |||
| error=node_run_result.error, | |||
| elapsed_time=time.perf_counter() - start_at, | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| @@ -321,25 +321,3 @@ class WorkflowService: | |||
| ) | |||
| else: | |||
| raise ValueError(f"Invalid app mode: {app_model.mode}") | |||
| @classmethod | |||
| def get_elapsed_time(cls, workflow_run_id: str) -> float: | |||
| """ | |||
| Get elapsed time | |||
| """ | |||
| elapsed_time = 0.0 | |||
| # fetch workflow node execution by workflow_run_id | |||
| workflow_nodes = ( | |||
| db.session.query(WorkflowNodeExecution) | |||
| .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) | |||
| .order_by(WorkflowNodeExecution.created_at.asc()) | |||
| .all() | |||
| ) | |||
| if not workflow_nodes: | |||
| return elapsed_time | |||
| for node in workflow_nodes: | |||
| elapsed_time += node.elapsed_time | |||
| return elapsed_time | |||
| @@ -1,17 +1,72 @@ | |||
| import time | |||
| import uuid | |||
| from os import getenv | |||
| from typing import cast | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import NodeRunResult, UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.code.code_node import CodeNode | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.code.entities import CodeNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||
| CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) | |||
| def init_code_node(code_config: dict): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-code-target", | |||
| "source": "start", | |||
| "target": "code", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["code", "123", "args1"], 1) | |||
| variable_pool.add(["code", "123", "args2"], 2) | |||
| node = CodeNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=code_config, | |||
| ) | |||
| return node | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| def test_execute_code(setup_code_executor_mock): | |||
| code = """ | |||
| @@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock): | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| node = CodeNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| ) | |||
| } | |||
| # construct variable pool | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["1", "123", "args1"], 1) | |||
| pool.add(["1", "123", "args2"], 2) | |||
| node = init_code_node(code_config) | |||
| # execute node | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["result"] == 3 | |||
| assert result.error is None | |||
| @@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock): | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| node = CodeNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "string", | |||
| }, | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "result": { | |||
| "type": "string", | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| ) | |||
| } | |||
| # construct variable pool | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["1", "123", "args1"], 1) | |||
| pool.add(["1", "123", "args2"], 2) | |||
| node = init_code_node(code_config) | |||
| # execute node | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||
| assert result.error == "Output variable `result` must be a string" | |||
| @@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth(): | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| node = CodeNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| "outputs": { | |||
| "string_validator": { | |||
| "type": "string", | |||
| }, | |||
| "number_validator": { | |||
| "type": "number", | |||
| }, | |||
| "number_array_validator": { | |||
| "type": "array[number]", | |||
| }, | |||
| "string_array_validator": { | |||
| "type": "array[string]", | |||
| }, | |||
| "object_validator": { | |||
| "type": "object", | |||
| "children": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "number", | |||
| } | |||
| }, | |||
| } | |||
| }, | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "string_validator": { | |||
| "type": "string", | |||
| }, | |||
| "number_validator": { | |||
| "type": "number", | |||
| }, | |||
| "number_array_validator": { | |||
| "type": "array[number]", | |||
| }, | |||
| "string_array_validator": { | |||
| "type": "array[string]", | |||
| }, | |||
| "object_validator": { | |||
| "type": "object", | |||
| "children": { | |||
| "result": { | |||
| "type": "number", | |||
| }, | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "object", | |||
| "children": { | |||
| "depth": { | |||
| "type": "number", | |||
| } | |||
| }, | |||
| } | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| ) | |||
| } | |||
| node = init_code_node(code_config) | |||
| # construct result | |||
| result = { | |||
| @@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth(): | |||
| "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, | |||
| } | |||
| node.node_data = cast(CodeNodeData, node.node_data) | |||
| # validate | |||
| node._transform_result(result, node.node_data.outputs) | |||
| @@ -250,35 +284,30 @@ def test_execute_code_output_object_list(): | |||
| """ | |||
| # trim first 4 spaces at the beginning of each line | |||
| code = "\n".join([line[4:] for line in code.split("\n")]) | |||
| node = CodeNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| "outputs": { | |||
| "object_list": { | |||
| "type": "array[object]", | |||
| }, | |||
| code_config = { | |||
| "id": "code", | |||
| "data": { | |||
| "outputs": { | |||
| "object_list": { | |||
| "type": "array[object]", | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "answer": "123", | |||
| "code_language": "python3", | |||
| "code": code, | |||
| }, | |||
| ) | |||
| } | |||
| node = init_code_node(code_config) | |||
| # construct result | |||
| result = { | |||
| @@ -295,6 +324,8 @@ def test_execute_code_output_object_list(): | |||
| ] | |||
| } | |||
| node.node_data = cast(CodeNodeData, node.node_data) | |||
| # validate | |||
| node._transform_result(result, node.node_data.outputs) | |||
| @@ -1,31 +1,69 @@ | |||
| import time | |||
| import uuid | |||
| from urllib.parse import urlencode | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.http_request.http_request_node import HttpRequestNode | |||
| from models.workflow import WorkflowType | |||
| from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock | |||
| BASIC_NODE_DATA = { | |||
| "tenant_id": "1", | |||
| "app_id": "1", | |||
| "workflow_id": "1", | |||
| "user_id": "1", | |||
| "user_from": UserFrom.ACCOUNT, | |||
| "invoke_from": InvokeFrom.WEB_APP, | |||
| } | |||
| # construct variable pool | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["a", "b123", "args1"], 1) | |||
| pool.add(["a", "b123", "args2"], 2) | |||
| def init_http_node(config: dict): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-next-target", | |||
| "source": "start", | |||
| "target": "1", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, config], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["a", "b123", "args1"], 1) | |||
| variable_pool.add(["a", "b123", "args2"], 2) | |||
| return HttpRequestNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=config, | |||
| ) | |||
| @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) | |||
| def test_get(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -45,12 +83,11 @@ def test_get(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": None, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "?A=b" in data | |||
| @@ -59,7 +96,7 @@ def test_get(setup_http_mock): | |||
| @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) | |||
| def test_no_auth(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": None, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "?A=b" in data | |||
| @@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock): | |||
| @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) | |||
| def test_custom_authorization_header(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": None, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "?A=b" in data | |||
| @@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock): | |||
| @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) | |||
| def test_template(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -143,11 +178,11 @@ def test_template(setup_http_mock): | |||
| "params": "A:b\nTemplate:{{#a.b123.args2#}}", | |||
| "body": None, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "?A=b" in data | |||
| @@ -158,7 +193,7 @@ def test_template(setup_http_mock): | |||
| @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) | |||
| def test_json(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -178,11 +213,11 @@ def test_json(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert '{"a": "1"}' in data | |||
| @@ -190,7 +225,7 @@ def test_json(setup_http_mock): | |||
| def test_x_www_form_urlencoded(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "a=1&b=2" in data | |||
| @@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock): | |||
| def test_form_data(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -242,11 +277,11 @@ def test_form_data(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert 'form-data; name="a"' in data | |||
| @@ -257,7 +292,7 @@ def test_form_data(setup_http_mock): | |||
| def test_none_data(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -277,11 +312,11 @@ def test_none_data(setup_http_mock): | |||
| "params": "A:b", | |||
| "body": {"type": "none", "data": "123123123"}, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| data = result.process_data.get("request", "") | |||
| assert "X-Header: 123" in data | |||
| @@ -289,7 +324,7 @@ def test_none_data(setup_http_mock): | |||
| def test_mock_404(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock): | |||
| "params": "", | |||
| "headers": "X-Header:123", | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.outputs is not None | |||
| resp = result.outputs | |||
| assert 404 == resp.get("status_code") | |||
| assert "Not Found" in resp.get("body") | |||
| assert "Not Found" in resp.get("body", "") | |||
| def test_multi_colons_parse(setup_http_mock): | |||
| node = HttpRequestNode( | |||
| node = init_http_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock): | |||
| "headers": "Referer:http://example3.com\nRedirect:http://example4.com", | |||
| "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, | |||
| }, | |||
| }, | |||
| **BASIC_NODE_DATA, | |||
| } | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.process_data is not None | |||
| assert result.outputs is not None | |||
| resp = result.outputs | |||
| assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") | |||
| assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") | |||
| assert "http://example3.com" == resp.get("headers").get("referer") | |||
| assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") | |||
| assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "") | |||
| assert "http://example3.com" == resp.get("headers", {}).get("referer") | |||
| @@ -1,5 +1,8 @@ | |||
| import json | |||
| import os | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator | |||
| from unittest.mock import MagicMock | |||
| import pytest | |||
| @@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers import ModelProviderFactory | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| from extensions.ext_database import db | |||
| from models.provider import ProviderType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| """FOR MOCK FIXTURES, DO NOT REMOVE""" | |||
| from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock | |||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||
| @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) | |||
| def test_execute_llm(setup_openai_mock): | |||
| node = LLMNode( | |||
| def init_llm_node(config: dict) -> LLMNode: | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-next-target", | |||
| "source": "start", | |||
| "target": "llm", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, config], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather today?", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["abc", "output"], "sunny") | |||
| node = LLMNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=config, | |||
| ) | |||
| return node | |||
| @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) | |||
| def test_execute_llm(setup_openai_mock): | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock): | |||
| }, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather today?", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["abc", "output"], "sunny") | |||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||
| provider_instance = ModelProviderFactory().get_provider_instance("openai") | |||
| @@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock): | |||
| model_type_instance=model_type_instance, | |||
| ) | |||
| model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") | |||
| model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") | |||
| assert model_schema is not None | |||
| model_config = ModelConfigWithCredentialsEntity( | |||
| model="gpt-3.5-turbo", | |||
| provider="openai", | |||
| mode="chat", | |||
| credentials=credentials, | |||
| parameters={}, | |||
| model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), | |||
| model_schema=model_schema, | |||
| provider_model_bundle=provider_model_bundle, | |||
| ) | |||
| @@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock): | |||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||
| # execute node | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert isinstance(result, Generator) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs["text"] is not None | |||
| assert result.outputs["usage"]["total_tokens"] > 0 | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert item.run_result.outputs is not None | |||
| assert item.run_result.outputs.get("text") is not None | |||
| assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| @@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| """ | |||
| Test execute LLM node with jinja2 | |||
| """ | |||
| node = LLMNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| }, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather today?", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["abc", "output"], "sunny") | |||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||
| provider_instance = ModelProviderFactory().get_provider_instance("openai") | |||
| @@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| ) | |||
| model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") | |||
| model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo") | |||
| assert model_schema is not None | |||
| model_config = ModelConfigWithCredentialsEntity( | |||
| model="gpt-3.5-turbo", | |||
| provider="openai", | |||
| mode="chat", | |||
| credentials=credentials, | |||
| parameters={}, | |||
| model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), | |||
| model_schema=model_schema, | |||
| provider_model_bundle=provider_model_bundle, | |||
| ) | |||
| @@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||
| # execute node | |||
| result = node.run(pool) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert "sunny" in json.dumps(result.process_data) | |||
| assert "what's the weather today?" in json.dumps(result.process_data) | |||
| result = node._run() | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert "sunny" in json.dumps(item.run_result.process_data) | |||
| assert "what's the weather today?" in json.dumps(item.run_result.process_data) | |||
| @@ -1,5 +1,7 @@ | |||
| import json | |||
| import os | |||
| import time | |||
| import uuid | |||
| from typing import Optional | |||
| from unittest.mock import MagicMock | |||
| @@ -8,19 +10,21 @@ import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |||
| from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode | |||
| from extensions.ext_database import db | |||
| from models.provider import ProviderType | |||
| """FOR MOCK FIXTURES, DO NOT REMOVE""" | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock | |||
| from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock | |||
| @@ -47,13 +51,15 @@ def get_mocked_fetch_model_config( | |||
| model_type_instance=model_type_instance, | |||
| ) | |||
| model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) | |||
| model_schema = model_type_instance.get_model_schema(model) | |||
| assert model_schema is not None | |||
| model_config = ModelConfigWithCredentialsEntity( | |||
| model=model, | |||
| provider=provider, | |||
| mode=mode, | |||
| credentials=credentials, | |||
| parameters={}, | |||
| model_schema=model_type_instance.get_model_schema(model), | |||
| model_schema=model_schema, | |||
| provider_model_bundle=provider_model_bundle, | |||
| ) | |||
| @@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str): | |||
| return MagicMock(return_value=MemoryMock()) | |||
| @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) | |||
| def test_function_calling_parameter_extractor(setup_openai_mock): | |||
| """ | |||
| Test function calling for parameter extractor. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| def init_parameter_extractor_node(config: dict): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-next-target", | |||
| "source": "start", | |||
| "target": "llm", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, config], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["a", "b123", "args1"], 1) | |||
| variable_pool.add(["a", "b123", "args2"], 2) | |||
| return ParameterExtractorNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=config, | |||
| ) | |||
| @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) | |||
| def test_function_calling_parameter_extractor(setup_openai_mock): | |||
| """ | |||
| Test function calling for parameter extractor. | |||
| """ | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock): | |||
| "reasoning_mode": "function_call", | |||
| "memory": None, | |||
| }, | |||
| }, | |||
| } | |||
| ) | |||
| node._fetch_model_config = get_mocked_fetch_model_config( | |||
| @@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): | |||
| environment_variables=[], | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs.get("location") == "kawaii" | |||
| assert result.outputs.get("__reason") == None | |||
| @@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock): | |||
| """ | |||
| Test chat parameter extractor. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock): | |||
| ) | |||
| db.session.close = MagicMock() | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs.get("location") == "kawaii" | |||
| assert result.outputs.get("__reason") == None | |||
| process_data = result.process_data | |||
| assert process_data is not None | |||
| process_data.get("prompts") | |||
| for prompt in process_data.get("prompts"): | |||
| for prompt in process_data.get("prompts", []): | |||
| if prompt.get("role") == "system": | |||
| assert "what's the weather in SF" in prompt.get("text") | |||
| @@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock): | |||
| """ | |||
| Test chat parameter extractor. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock): | |||
| ) | |||
| db.session.close = MagicMock() | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs.get("location") == "" | |||
| assert ( | |||
| result.outputs.get("__reason") | |||
| == "Failed to extract result from function call or text response, using empty result." | |||
| ) | |||
| prompts = result.process_data.get("prompts") | |||
| assert result.process_data is not None | |||
| prompts = result.process_data.get("prompts", []) | |||
| for prompt in prompts: | |||
| if prompt.get("role") == "user": | |||
| @@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock): | |||
| """ | |||
| Test completion parameter extractor. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock): | |||
| ) | |||
| db.session.close = MagicMock() | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs.get("location") == "" | |||
| assert ( | |||
| result.outputs.get("__reason") | |||
| == "Failed to extract result from function call or text response, using empty result." | |||
| ) | |||
| assert len(result.process_data.get("prompts")) == 1 | |||
| assert "SF" in result.process_data.get("prompts")[0].get("text") | |||
| assert result.process_data is not None | |||
| assert len(result.process_data.get("prompts", [])) == 1 | |||
| assert "SF" in result.process_data.get("prompts", [])[0].get("text") | |||
| def test_extract_json_response(): | |||
| @@ -322,13 +325,7 @@ def test_extract_json_response(): | |||
| Test extract json response. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -357,6 +354,7 @@ def test_extract_json_response(): | |||
| hello world. | |||
| """) | |||
| assert result is not None | |||
| assert result["location"] == "kawaii" | |||
| @@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): | |||
| """ | |||
| Test chat parameter extractor with memory. | |||
| """ | |||
| node = ParameterExtractorNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_parameter_extractor_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| @@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): | |||
| node._fetch_memory = get_mocked_fetch_memory("customized memory") | |||
| db.session.close = MagicMock() | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs.get("location") == "" | |||
| assert ( | |||
| result.outputs.get("__reason") | |||
| == "Failed to extract result from function call or text response, using empty result." | |||
| ) | |||
| prompts = result.process_data.get("prompts") | |||
| assert result.process_data is not None | |||
| prompts = result.process_data.get("prompts", []) | |||
| latest_role = None | |||
| for prompt in prompts: | |||
| @@ -1,46 +1,84 @@ | |||
| import time | |||
| import uuid | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| def test_execute_code(setup_code_executor_mock): | |||
| code = """{{args2}}""" | |||
| node = TemplateTransformNode( | |||
| config = { | |||
| "id": "1", | |||
| "data": { | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "template": code, | |||
| }, | |||
| } | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-next-target", | |||
| "source": "start", | |||
| "target": "1", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, config], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.END_USER, | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| "title": "123", | |||
| "variables": [ | |||
| { | |||
| "variable": "args1", | |||
| "value_selector": ["1", "123", "args1"], | |||
| }, | |||
| {"variable": "args2", "value_selector": ["1", "123", "args2"]}, | |||
| ], | |||
| "template": code, | |||
| }, | |||
| }, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["1", "123", "args1"], 1) | |||
| pool.add(["1", "123", "args2"], 3) | |||
| variable_pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| variable_pool.add(["1", "123", "args1"], 1) | |||
| variable_pool.add(["1", "123", "args2"], 3) | |||
| node = TemplateTransformNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=config, | |||
| ) | |||
| # execute node | |||
| result = node.run(pool) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["output"] == "3" | |||
| @@ -1,21 +1,62 @@ | |||
| import time | |||
| import uuid | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import NodeRunResult, UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.tool.tool_node import ToolNode | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| def test_tool_variable_invoke(): | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["1", "123", "args1"], "1+1") | |||
| def init_tool_node(config: dict): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-next-target", | |||
| "source": "start", | |||
| "target": "1", | |||
| }, | |||
| ], | |||
| "nodes": [{"data": {"type": "start"}, "id": "start"}, config], | |||
| } | |||
| node = ToolNode( | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| return ToolNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), | |||
| config=config, | |||
| ) | |||
| def test_tool_variable_invoke(): | |||
| node = init_tool_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -34,28 +75,22 @@ def test_tool_variable_invoke(): | |||
| } | |||
| }, | |||
| }, | |||
| }, | |||
| } | |||
| ) | |||
| # execute node | |||
| result = node.run(pool) | |||
| node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1") | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert "2" in result.outputs["text"] | |||
| assert result.outputs["files"] == [] | |||
| def test_tool_mixed_invoke(): | |||
| pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) | |||
| pool.add(["1", "args1"], "1+1") | |||
| node = ToolNode( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| user_id="1", | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| user_from=UserFrom.ACCOUNT, | |||
| node = init_tool_node( | |||
| config={ | |||
| "id": "1", | |||
| "data": { | |||
| @@ -74,12 +109,15 @@ def test_tool_mixed_invoke(): | |||
| } | |||
| }, | |||
| }, | |||
| }, | |||
| } | |||
| ) | |||
| # execute node | |||
| result = node.run(pool) | |||
| node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert "2" in result.outputs["text"] | |||
| assert result.outputs["files"] == [] | |||
| @@ -1,7 +1,24 @@ | |||
| import os | |||
| import pytest | |||
| from flask import Flask | |||
| # Getting the absolute path of the current file's directory | |||
| ABS_PATH = os.path.dirname(os.path.abspath(__file__)) | |||
| # Getting the absolute path of the project's root directory | |||
| PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) | |||
| CACHED_APP = Flask(__name__) | |||
| CACHED_APP.config.update({"TESTING": True}) | |||
| @pytest.fixture() | |||
| def app() -> Flask: | |||
| return CACHED_APP | |||
| @pytest.fixture(autouse=True) | |||
| def _provide_app_context(app: Flask): | |||
| with app.app_context(): | |||
| yield | |||
| @@ -0,0 +1,791 @@ | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.run_condition import RunCondition | |||
| from core.workflow.utils.condition.entities import Condition | |||
| def test_init(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "llm-source-answer-target", | |||
| "source": "llm", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "start-source-qc-target", | |||
| "source": "start", | |||
| "target": "qc", | |||
| }, | |||
| { | |||
| "id": "qc-1-llm-target", | |||
| "source": "qc", | |||
| "sourceHandle": "1", | |||
| "target": "llm", | |||
| }, | |||
| { | |||
| "id": "qc-2-http-target", | |||
| "source": "qc", | |||
| "sourceHandle": "2", | |||
| "target": "http", | |||
| }, | |||
| { | |||
| "id": "http-source-answer2-target", | |||
| "source": "http", | |||
| "target": "answer2", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| { | |||
| "data": {"type": "question-classifier"}, | |||
| "id": "qc", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "http-request", | |||
| }, | |||
| "id": "http", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer2", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| start_node_id = "start" | |||
| assert graph.root_node_id == start_node_id | |||
| assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" | |||
| assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} | |||
| def test__init_iteration_graph(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "llm-answer", | |||
| "source": "llm", | |||
| "sourceHandle": "source", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "iteration-source-llm-target", | |||
| "source": "iteration", | |||
| "sourceHandle": "source", | |||
| "target": "llm", | |||
| }, | |||
| { | |||
| "id": "template-transform-in-iteration-source-llm-in-iteration-target", | |||
| "source": "template-transform-in-iteration", | |||
| "sourceHandle": "source", | |||
| "target": "llm-in-iteration", | |||
| }, | |||
| { | |||
| "id": "llm-in-iteration-source-answer-in-iteration-target", | |||
| "source": "llm-in-iteration", | |||
| "sourceHandle": "source", | |||
| "target": "answer-in-iteration", | |||
| }, | |||
| { | |||
| "id": "start-source-code-target", | |||
| "source": "start", | |||
| "sourceHandle": "source", | |||
| "target": "code", | |||
| }, | |||
| { | |||
| "id": "code-source-iteration-target", | |||
| "source": "code", | |||
| "sourceHandle": "source", | |||
| "target": "iteration", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| { | |||
| "data": { | |||
| "type": "start", | |||
| }, | |||
| "id": "start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| { | |||
| "data": {"type": "iteration"}, | |||
| "id": "iteration", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "template-transform", | |||
| }, | |||
| "id": "template-transform-in-iteration", | |||
| "parentId": "iteration", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm-in-iteration", | |||
| "parentId": "iteration", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer-in-iteration", | |||
| "parentId": "iteration", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration") | |||
| graph.add_extra_edge( | |||
| source_node_id="answer-in-iteration", | |||
| target_node_id="template-transform-in-iteration", | |||
| run_condition=RunCondition( | |||
| type="condition", | |||
| conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")], | |||
| ), | |||
| ) | |||
| # iteration: | |||
| # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] | |||
| assert graph.root_node_id == "template-transform-in-iteration" | |||
| assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" | |||
| assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" | |||
| assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" | |||
| def test_parallels_graph(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "llm1-source-answer-target", | |||
| "source": "llm1", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm2-source-answer-target", | |||
| "source": "llm2", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm3-source-answer-target", | |||
| "source": "llm3", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(3): | |||
| start_edges = graph.edge_mapping.get("start") | |||
| assert start_edges is not None | |||
| assert start_edges[i].target_node_id == f"llm{i+1}" | |||
| llm_edges = graph.edge_mapping.get(f"llm{i+1}") | |||
| assert llm_edges is not None | |||
| assert llm_edges[0].target_node_id == "answer" | |||
| assert len(graph.parallel_mapping) == 1 | |||
| assert len(graph.node_parallel_mapping) == 3 | |||
| for node_id in ["llm1", "llm2", "llm3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| def test_parallels_graph2(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "llm1-source-answer-target", | |||
| "source": "llm1", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm2-source-answer-target", | |||
| "source": "llm2", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(3): | |||
| assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" | |||
| if i < 2: | |||
| assert graph.edge_mapping.get(f"llm{i + 1}") is not None | |||
| assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" | |||
| assert len(graph.parallel_mapping) == 1 | |||
| assert len(graph.node_parallel_mapping) == 3 | |||
| for node_id in ["llm1", "llm2", "llm3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| def test_parallels_graph3(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(3): | |||
| assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" | |||
| assert len(graph.parallel_mapping) == 1 | |||
| assert len(graph.node_parallel_mapping) == 3 | |||
| for node_id in ["llm1", "llm2", "llm3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| def test_parallels_graph4(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "llm1-source-answer-target", | |||
| "source": "llm1", | |||
| "target": "code1", | |||
| }, | |||
| { | |||
| "id": "llm2-source-answer-target", | |||
| "source": "llm2", | |||
| "target": "code2", | |||
| }, | |||
| { | |||
| "id": "llm3-source-code3-target", | |||
| "source": "llm3", | |||
| "target": "code3", | |||
| }, | |||
| { | |||
| "id": "code1-source-answer-target", | |||
| "source": "code1", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "code2-source-answer-target", | |||
| "source": "code2", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "code3-source-answer-target", | |||
| "source": "code3", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(3): | |||
| assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" | |||
| assert graph.edge_mapping.get(f"llm{i + 1}") is not None | |||
| assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" | |||
| assert graph.edge_mapping.get(f"code{i + 1}") is not None | |||
| assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" | |||
| assert len(graph.parallel_mapping) == 1 | |||
| assert len(graph.node_parallel_mapping) == 6 | |||
| for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| def test_parallels_graph5(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm4", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm5", | |||
| }, | |||
| { | |||
| "id": "llm1-source-code1-target", | |||
| "source": "llm1", | |||
| "target": "code1", | |||
| }, | |||
| { | |||
| "id": "llm2-source-code1-target", | |||
| "source": "llm2", | |||
| "target": "code1", | |||
| }, | |||
| { | |||
| "id": "llm3-source-code2-target", | |||
| "source": "llm3", | |||
| "target": "code2", | |||
| }, | |||
| { | |||
| "id": "llm4-source-code2-target", | |||
| "source": "llm4", | |||
| "target": "code2", | |||
| }, | |||
| { | |||
| "id": "llm5-source-code3-target", | |||
| "source": "llm5", | |||
| "target": "code3", | |||
| }, | |||
| { | |||
| "id": "code1-source-answer-target", | |||
| "source": "code1", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "code2-source-answer-target", | |||
| "source": "code2", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm5", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(5): | |||
| assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" | |||
| assert graph.edge_mapping.get("llm1") is not None | |||
| assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" | |||
| assert graph.edge_mapping.get("llm2") is not None | |||
| assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" | |||
| assert graph.edge_mapping.get("llm3") is not None | |||
| assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" | |||
| assert graph.edge_mapping.get("llm4") is not None | |||
| assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" | |||
| assert graph.edge_mapping.get("llm5") is not None | |||
| assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" | |||
| assert graph.edge_mapping.get("code1") is not None | |||
| assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" | |||
| assert graph.edge_mapping.get("code2") is not None | |||
| assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" | |||
| assert len(graph.parallel_mapping) == 1 | |||
| assert len(graph.node_parallel_mapping) == 8 | |||
| for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| def test_parallels_graph6(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "llm1-source-code1-target", | |||
| "source": "llm1", | |||
| "target": "code1", | |||
| }, | |||
| { | |||
| "id": "llm1-source-code2-target", | |||
| "source": "llm1", | |||
| "target": "code2", | |||
| }, | |||
| { | |||
| "id": "llm2-source-code3-target", | |||
| "source": "llm2", | |||
| "target": "code3", | |||
| }, | |||
| { | |||
| "id": "code1-source-answer-target", | |||
| "source": "code1", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "code2-source-answer-target", | |||
| "source": "code2", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "code3-source-answer-target", | |||
| "source": "code3", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm3-source-answer-target", | |||
| "source": "llm3", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "code", | |||
| }, | |||
| "id": "code3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1"}, | |||
| "id": "answer", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| assert graph.root_node_id == "start" | |||
| for i in range(3): | |||
| assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" | |||
| assert graph.edge_mapping.get("llm1") is not None | |||
| assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" | |||
| assert graph.edge_mapping.get("llm1") is not None | |||
| assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" | |||
| assert graph.edge_mapping.get("llm2") is not None | |||
| assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" | |||
| assert graph.edge_mapping.get("code1") is not None | |||
| assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" | |||
| assert graph.edge_mapping.get("code2") is not None | |||
| assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" | |||
| assert graph.edge_mapping.get("code3") is not None | |||
| assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" | |||
| assert len(graph.parallel_mapping) == 2 | |||
| assert len(graph.node_parallel_mapping) == 6 | |||
| for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: | |||
| assert node_id in graph.node_parallel_mapping | |||
| parent_parallel = None | |||
| child_parallel = None | |||
| for p_id, parallel in graph.parallel_mapping.items(): | |||
| if parallel.parent_parallel_id is None: | |||
| parent_parallel = parallel | |||
| else: | |||
| child_parallel = parallel | |||
| for node_id in ["llm1", "llm2", "llm3", "code3"]: | |||
| assert graph.node_parallel_mapping[node_id] == parent_parallel.id | |||
| for node_id in ["code1", "code2"]: | |||
| assert graph.node_parallel_mapping[node_id] == child_parallel.id | |||
| @@ -0,0 +1,505 @@ | |||
| from unittest.mock import patch | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseNodeEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| @patch("extensions.ext_database.db.session.remove") | |||
| @patch("extensions.ext_database.db.session.close") | |||
| def test_run_parallel_in_workflow(mock_close, mock_remove): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "1", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "2", | |||
| "source": "llm1", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "3", | |||
| "source": "llm1", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "4", | |||
| "source": "llm2", | |||
| "target": "end1", | |||
| }, | |||
| { | |||
| "id": "5", | |||
| "source": "llm3", | |||
| "target": "end2", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| { | |||
| "data": { | |||
| "type": "start", | |||
| "title": "start", | |||
| "variables": [ | |||
| { | |||
| "label": "query", | |||
| "max_length": 48, | |||
| "options": [], | |||
| "required": True, | |||
| "type": "text-input", | |||
| "variable": "query", | |||
| } | |||
| ], | |||
| }, | |||
| "id": "start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| "title": "llm1", | |||
| "context": {"enabled": False, "variable_selector": []}, | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "say hi"}, | |||
| {"role": "user", "text": "{{#start.query#}}"}, | |||
| ], | |||
| "vision": {"configs": {"detail": "high"}, "enabled": False}, | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| "title": "llm2", | |||
| "context": {"enabled": False, "variable_selector": []}, | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "say bye"}, | |||
| {"role": "user", "text": "{{#start.query#}}"}, | |||
| ], | |||
| "vision": {"configs": {"detail": "high"}, "enabled": False}, | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| "title": "llm3", | |||
| "context": {"enabled": False, "variable_selector": []}, | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "say good morning"}, | |||
| {"role": "user", "text": "{{#start.query#}}"}, | |||
| ], | |||
| "vision": {"configs": {"detail": "high"}, "enabled": False}, | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "end", | |||
| "title": "end1", | |||
| "outputs": [ | |||
| {"value_selector": ["llm2", "text"], "variable": "result2"}, | |||
| {"value_selector": ["start", "query"], "variable": "query"}, | |||
| ], | |||
| }, | |||
| "id": "end1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "end", | |||
| "title": "end2", | |||
| "outputs": [ | |||
| {"value_selector": ["llm1", "text"], "variable": "result1"}, | |||
| {"value_selector": ["llm3", "text"], "variable": "result3"}, | |||
| ], | |||
| }, | |||
| "id": "end2", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| variable_pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} | |||
| ) | |||
| graph_engine = GraphEngine( | |||
| tenant_id="111", | |||
| app_id="222", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="333", | |||
| graph_config=graph_config, | |||
| user_id="444", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| call_depth=0, | |||
| graph=graph, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=500, | |||
| max_execution_time=1200, | |||
| ) | |||
| def llm_generator(self): | |||
| contents = ["hi", "bye", "good morning"] | |||
| yield RunStreamChunkEvent( | |||
| chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"] | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={}, | |||
| process_data={}, | |||
| outputs={}, | |||
| metadata={ | |||
| NodeRunMetadataKey.TOTAL_TOKENS: 1, | |||
| NodeRunMetadataKey.TOTAL_PRICE: 1, | |||
| NodeRunMetadataKey.CURRENCY: "USD", | |||
| }, | |||
| ) | |||
| ) | |||
| # print("") | |||
| with patch.object(LLMNode, "_run", new=llm_generator): | |||
| items = [] | |||
| generator = graph_engine.run() | |||
| for item in generator: | |||
| # print(type(item), item) | |||
| items.append(item) | |||
| if isinstance(item, NodeRunSucceededEvent): | |||
| assert item.route_node_state.status == RouteNodeState.Status.SUCCESS | |||
| assert not isinstance(item, NodeRunFailedEvent) | |||
| assert not isinstance(item, GraphRunFailedEvent) | |||
| if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: | |||
| assert item.parallel_id is not None | |||
| assert len(items) == 18 | |||
| assert isinstance(items[0], GraphRunStartedEvent) | |||
| assert isinstance(items[1], NodeRunStartedEvent) | |||
| assert items[1].route_node_state.node_id == "start" | |||
| assert isinstance(items[2], NodeRunSucceededEvent) | |||
| assert items[2].route_node_state.node_id == "start" | |||
| @patch("extensions.ext_database.db.session.remove") | |||
| @patch("extensions.ext_database.db.session.close") | |||
| def test_run_parallel_in_chatflow(mock_close, mock_remove): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "1", | |||
| "source": "start", | |||
| "target": "answer1", | |||
| }, | |||
| { | |||
| "id": "2", | |||
| "source": "answer1", | |||
| "target": "answer2", | |||
| }, | |||
| { | |||
| "id": "3", | |||
| "source": "answer1", | |||
| "target": "answer3", | |||
| }, | |||
| { | |||
| "id": "4", | |||
| "source": "answer2", | |||
| "target": "answer4", | |||
| }, | |||
| { | |||
| "id": "5", | |||
| "source": "answer3", | |||
| "target": "answer5", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start", "title": "start"}, "id": "start"}, | |||
| {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"}, | |||
| { | |||
| "data": {"type": "answer", "title": "answer2", "answer": "2"}, | |||
| "id": "answer2", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer3", "answer": "3"}, | |||
| "id": "answer3", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer4", "answer": "4"}, | |||
| "id": "answer4", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer5", "answer": "5"}, | |||
| "id": "answer5", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| variable_pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "what's the weather in SF", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={}, | |||
| ) | |||
| graph_engine = GraphEngine( | |||
| tenant_id="111", | |||
| app_id="222", | |||
| workflow_type=WorkflowType.CHAT, | |||
| workflow_id="333", | |||
| graph_config=graph_config, | |||
| user_id="444", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| call_depth=0, | |||
| graph=graph, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=500, | |||
| max_execution_time=1200, | |||
| ) | |||
| # print("") | |||
| items = [] | |||
| generator = graph_engine.run() | |||
| for item in generator: | |||
| # print(type(item), item) | |||
| items.append(item) | |||
| if isinstance(item, NodeRunSucceededEvent): | |||
| assert item.route_node_state.status == RouteNodeState.Status.SUCCESS | |||
| assert not isinstance(item, NodeRunFailedEvent) | |||
| assert not isinstance(item, GraphRunFailedEvent) | |||
| if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ | |||
| "answer2", | |||
| "answer3", | |||
| "answer4", | |||
| "answer5", | |||
| ]: | |||
| assert item.parallel_id is not None | |||
| assert len(items) == 23 | |||
| assert isinstance(items[0], GraphRunStartedEvent) | |||
| assert isinstance(items[1], NodeRunStartedEvent) | |||
| assert items[1].route_node_state.node_id == "start" | |||
| assert isinstance(items[2], NodeRunSucceededEvent) | |||
| assert items[2].route_node_state.node_id == "start" | |||
| @patch("extensions.ext_database.db.session.remove") | |||
| @patch("extensions.ext_database.db.session.close") | |||
| def test_run_branch(mock_close, mock_remove): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "1", | |||
| "source": "start", | |||
| "target": "if-else-1", | |||
| }, | |||
| { | |||
| "id": "2", | |||
| "source": "if-else-1", | |||
| "sourceHandle": "true", | |||
| "target": "answer-1", | |||
| }, | |||
| { | |||
| "id": "3", | |||
| "source": "if-else-1", | |||
| "sourceHandle": "false", | |||
| "target": "if-else-2", | |||
| }, | |||
| { | |||
| "id": "4", | |||
| "source": "if-else-2", | |||
| "sourceHandle": "true", | |||
| "target": "answer-2", | |||
| }, | |||
| { | |||
| "id": "5", | |||
| "source": "if-else-2", | |||
| "sourceHandle": "false", | |||
| "target": "answer-3", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| { | |||
| "data": { | |||
| "title": "Start", | |||
| "type": "start", | |||
| "variables": [ | |||
| { | |||
| "label": "uid", | |||
| "max_length": 48, | |||
| "options": [], | |||
| "required": True, | |||
| "type": "text-input", | |||
| "variable": "uid", | |||
| } | |||
| ], | |||
| }, | |||
| "id": "start", | |||
| }, | |||
| { | |||
| "data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []}, | |||
| "id": "answer-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "cases": [ | |||
| { | |||
| "case_id": "true", | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "contains", | |||
| "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", | |||
| "value": "hi", | |||
| "varType": "string", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "id": "true", | |||
| "logical_operator": "and", | |||
| } | |||
| ], | |||
| "desc": "", | |||
| "title": "IF/ELSE", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "cases": [ | |||
| { | |||
| "case_id": "true", | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "contains", | |||
| "id": "ae895199-5608-433b-b5f0-0997ae1431e4", | |||
| "value": "takatost", | |||
| "varType": "string", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "id": "true", | |||
| "logical_operator": "and", | |||
| } | |||
| ], | |||
| "title": "IF/ELSE 2", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "2", | |||
| "title": "Answer 2", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "3", | |||
| "title": "Answer 3", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-3", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| variable_pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "hi", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| user_inputs={"uid": "takato"}, | |||
| ) | |||
| graph_engine = GraphEngine( | |||
| tenant_id="111", | |||
| app_id="222", | |||
| workflow_type=WorkflowType.CHAT, | |||
| workflow_id="333", | |||
| graph_config=graph_config, | |||
| user_id="444", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| call_depth=0, | |||
| graph=graph, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=500, | |||
| max_execution_time=1200, | |||
| ) | |||
| # print("") | |||
| items = [] | |||
| generator = graph_engine.run() | |||
| for item in generator: | |||
| # print(type(item), item) | |||
| items.append(item) | |||
| assert len(items) == 10 | |||
| assert items[3].route_node_state.node_id == "if-else-1" | |||
| assert items[4].route_node_state.node_id == "if-else-1" | |||
| assert isinstance(items[5], NodeRunStreamChunkEvent) | |||
| assert items[5].chunk_content == "1 " | |||
| assert isinstance(items[6], NodeRunStreamChunkEvent) | |||
| assert items[6].chunk_content == "takato" | |||
| assert items[7].route_node_state.node_id == "answer-1" | |||
| assert items[8].route_node_state.node_id == "answer-1" | |||
| assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato" | |||
| assert isinstance(items[9], GraphRunSucceededEvent) | |||
| # print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) | |||
| @@ -0,0 +1,82 @@ | |||
| import time | |||
| import uuid | |||
| from unittest.mock import MagicMock | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from extensions.ext_database import db | |||
| from models.workflow import WorkflowNodeExecutionStatus, WorkflowType | |||
| def test_execute_answer(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm-target", | |||
| "source": "start", | |||
| "target": "llm", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["start", "weather"], "sunny") | |||
| pool.add(["llm", "text"], "You are a helpful AI.") | |||
| node = AnswerNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), | |||
| config={ | |||
| "id": "answer", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "answer", | |||
| "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", | |||
| }, | |||
| }, | |||
| ) | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| # execute node | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." | |||
| @@ -0,0 +1,109 @@ | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter | |||
| def test_init(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-llm1-target", | |||
| "source": "start", | |||
| "target": "llm1", | |||
| }, | |||
| { | |||
| "id": "start-source-llm2-target", | |||
| "source": "start", | |||
| "target": "llm2", | |||
| }, | |||
| { | |||
| "id": "start-source-llm3-target", | |||
| "source": "start", | |||
| "target": "llm3", | |||
| }, | |||
| { | |||
| "id": "llm3-source-llm4-target", | |||
| "source": "llm3", | |||
| "target": "llm4", | |||
| }, | |||
| { | |||
| "id": "llm3-source-llm5-target", | |||
| "source": "llm3", | |||
| "target": "llm5", | |||
| }, | |||
| { | |||
| "id": "llm4-source-answer2-target", | |||
| "source": "llm4", | |||
| "target": "answer2", | |||
| }, | |||
| { | |||
| "id": "llm5-source-answer-target", | |||
| "source": "llm5", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "answer2-source-answer-target", | |||
| "source": "answer2", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm2-source-answer-target", | |||
| "source": "llm2", | |||
| "target": "answer", | |||
| }, | |||
| { | |||
| "id": "llm1-source-answer-target", | |||
| "source": "llm1", | |||
| "target": "answer", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"type": "start"}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "type": "llm", | |||
| }, | |||
| "id": "llm5", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"}, | |||
| "id": "answer", | |||
| }, | |||
| { | |||
| "data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"}, | |||
| "id": "answer2", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| answer_stream_generate_route = AnswerStreamGeneratorRouter.init( | |||
| node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping | |||
| ) | |||
| assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"] | |||
| assert answer_stream_generate_route.answer_dependencies["answer2"] == [] | |||