Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>tags/1.4.1
| from typing import cast | |||||
| from flask_login import current_user | |||||
| from flask_restful import Resource, marshal_with, reqparse | from flask_restful import Resource, marshal_with, reqparse | ||||
| from flask_restful.inputs import int_range | from flask_restful.inputs import int_range | ||||
| ) | ) | ||||
| from libs.helper import uuid_value | from libs.helper import uuid_value | ||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from models import App | |||||
| from models.model import AppMode | |||||
| from models import Account, App, AppMode, EndUser | |||||
| from services.workflow_run_service import WorkflowRunService | from services.workflow_run_service import WorkflowRunService | ||||
| run_id = str(run_id) | run_id = str(run_id) | ||||
| workflow_run_service = WorkflowRunService() | workflow_run_service = WorkflowRunService() | ||||
| node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) | |||||
| user = cast("Account | EndUser", current_user) | |||||
| node_executions = workflow_run_service.get_workflow_run_node_executions( | |||||
| app_model=app_model, | |||||
| run_id=run_id, | |||||
| user=user, | |||||
| ) | |||||
| return {"data": node_executions} | return {"data": node_executions} | ||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from models.account import Account | |||||
| from models.model import App, Conversation, EndUser, Message | |||||
| from models.workflow import Workflow | |||||
| from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | |||||
| from services.conversation_service import ConversationService | from services.conversation_service import ConversationService | ||||
| from services.errors.message import MessageNotExistsError | from services.errors.message import MessageNotExistsError | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( |
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models import Conversation, EndUser, Message, MessageFile | from models import Conversation, EndUser, Message, MessageFile | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.workflow import ( | from models.workflow import ( | ||||
| Workflow, | Workflow, | ||||
| WorkflowRunStatus, | WorkflowRunStatus, | ||||
| if isinstance(user, EndUser): | if isinstance(user, EndUser): | ||||
| self._user_id = user.id | self._user_id = user.id | ||||
| user_session_id = user.session_id | user_session_id = user.session_id | ||||
| self._created_by_role = CreatedByRole.END_USER | |||||
| self._created_by_role = CreatorUserRole.END_USER | |||||
| elif isinstance(user, Account): | elif isinstance(user, Account): | ||||
| self._user_id = user.id | self._user_id = user.id | ||||
| user_session_id = user.id | user_session_id = user.id | ||||
| self._created_by_role = CreatedByRole.ACCOUNT | |||||
| self._created_by_role = CreatorUserRole.ACCOUNT | |||||
| else: | else: | ||||
| raise NotImplementedError(f"User type not supported: {type(user)}") | raise NotImplementedError(f"User type not supported: {type(user)}") | ||||
| url=file["remote_url"], | url=file["remote_url"], | ||||
| belongs_to="assistant", | belongs_to="assistant", | ||||
| upload_file_id=file["related_id"], | upload_file_id=file["related_id"], | ||||
| created_by_role=CreatedByRole.ACCOUNT | |||||
| created_by_role=CreatorUserRole.ACCOUNT | |||||
| if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | ||||
| else CreatedByRole.END_USER, | |||||
| else CreatorUserRole.END_USER, | |||||
| created_by=message.from_account_id or message.from_end_user_id or "", | created_by=message.from_account_id or message.from_end_user_id or "", | ||||
| ) | ) | ||||
| for file in self._recorded_files | for file in self._recorded_files |
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | from core.prompt.utils.prompt_template_parser import PromptTemplateParser | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models import Account | from models import Account | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile | from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile | ||||
| from services.errors.app_model_config import AppModelConfigBrokenError | from services.errors.app_model_config import AppModelConfigBrokenError | ||||
| from services.errors.conversation import ConversationNotExistsError | from services.errors.conversation import ConversationNotExistsError | ||||
| belongs_to="user", | belongs_to="user", | ||||
| url=file.remote_url, | url=file.remote_url, | ||||
| upload_file_id=file.related_id, | upload_file_id=file.related_id, | ||||
| created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER), | |||||
| created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), | |||||
| created_by=account_id or end_user_id or "", | created_by=account_id or end_user_id or "", | ||||
| ) | ) | ||||
| db.session.add(message_file) | db.session.add(message_file) |
| from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline | from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from models import Account, App, EndUser, Workflow | |||||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, | session_factory=session_factory, | ||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| user=user, | |||||
| app_id=application_generate_entity.app_config.app_id, | app_id=application_generate_entity.app_config.app_id, | ||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( |
| from core.model_runtime.entities.llm_entities import LLMResult | from core.model_runtime.entities.llm_entities import LLMResult | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey | |||||
| from models.workflow import WorkflowNodeExecutionStatus | from models.workflow import WorkflowNodeExecutionStatus | ||||
| title: str | title: str | ||||
| index: int | index: int | ||||
| predecessor_node_id: Optional[str] = None | predecessor_node_id: Optional[str] = None | ||||
| inputs: Optional[dict] = None | |||||
| inputs: Optional[Mapping[str, Any]] = None | |||||
| created_at: int | created_at: int | ||||
| extras: dict = {} | extras: dict = {} | ||||
| parallel_id: Optional[str] = None | parallel_id: Optional[str] = None | ||||
| title: str | title: str | ||||
| index: int | index: int | ||||
| predecessor_node_id: Optional[str] = None | predecessor_node_id: Optional[str] = None | ||||
| inputs: Optional[dict] = None | |||||
| process_data: Optional[dict] = None | |||||
| outputs: Optional[dict] = None | |||||
| inputs: Optional[Mapping[str, Any]] = None | |||||
| process_data: Optional[Mapping[str, Any]] = None | |||||
| outputs: Optional[Mapping[str, Any]] = None | |||||
| status: str | status: str | ||||
| error: Optional[str] = None | error: Optional[str] = None | ||||
| elapsed_time: float | elapsed_time: float | ||||
| execution_metadata: Optional[dict] = None | |||||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||||
| created_at: int | created_at: int | ||||
| finished_at: int | finished_at: int | ||||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | files: Optional[Sequence[Mapping[str, Any]]] = [] | ||||
| title: str | title: str | ||||
| index: int | index: int | ||||
| predecessor_node_id: Optional[str] = None | predecessor_node_id: Optional[str] = None | ||||
| inputs: Optional[dict] = None | |||||
| process_data: Optional[dict] = None | |||||
| outputs: Optional[dict] = None | |||||
| inputs: Optional[Mapping[str, Any]] = None | |||||
| process_data: Optional[Mapping[str, Any]] = None | |||||
| outputs: Optional[Mapping[str, Any]] = None | |||||
| status: str | status: str | ||||
| error: Optional[str] = None | error: Optional[str] = None | ||||
| elapsed_time: float | elapsed_time: float | ||||
| execution_metadata: Optional[dict] = None | |||||
| execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None | |||||
| created_at: int | created_at: int | ||||
| finished_at: int | finished_at: int | ||||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | files: Optional[Sequence[Mapping[str, Any]]] = [] |
| from collections.abc import Mapping | |||||
| from datetime import datetime | from datetime import datetime | ||||
| from enum import StrEnum | from enum import StrEnum | ||||
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| description="The status message of the span. Additional field for context of the event. E.g. the error " | description="The status message of the span. Additional field for context of the event. E.g. the error " | ||||
| "message of an error event.", | "message of an error event.", | ||||
| ) | ) | ||||
| input: Optional[Union[str, dict[str, Any], list, None]] = Field( | |||||
| input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( | |||||
| default=None, description="The input of the span. Can be any JSON object." | default=None, description="The input of the span. Can be any JSON object." | ||||
| ) | ) | ||||
| output: Optional[Union[str, dict[str, Any], list, None]] = Field( | |||||
| output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( | |||||
| default=None, description="The output of the span. Can be any JSON object." | default=None, description="The output of the span. Can be any JSON object." | ||||
| ) | ) | ||||
| version: Optional[str] = Field( | version: Optional[str] = Field( |
| import json | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||
| from typing import Optional | from typing import Optional | ||||
| from langfuse import Langfuse # type: ignore | from langfuse import Langfuse # type: ignore | ||||
| from sqlalchemy.orm import sessionmaker | |||||
| from sqlalchemy.orm import Session, sessionmaker | |||||
| from core.ops.base_trace_instance import BaseTraceInstance | from core.ops.base_trace_instance import BaseTraceInstance | ||||
| from core.ops.entities.config_entity import LangfuseConfig | from core.ops.entities.config_entity import LangfuseConfig | ||||
| ) | ) | ||||
| from core.ops.utils import filter_none_values | from core.ops.utils import filter_none_values | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | ||||
| from core.workflow.nodes.enums import NodeType | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser | |||||
| from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| # Find the app's creator account | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| # Get the app to find its creator | |||||
| app_id = trace_info.metadata.get("app_id") | |||||
| if not app_id: | |||||
| raise ValueError("No app_id found in trace_info metadata") | |||||
| app = session.query(App).filter(App.id == app_id).first() | |||||
| if not app: | |||||
| raise ValueError(f"App with id {app_id} not found") | |||||
| if not app.created_by: | |||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||||
| if not service_account: | |||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id | |||||
| session_factory=session_factory, | |||||
| user=service_account, | |||||
| app_id=trace_info.metadata.get("app_id"), | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run | ||||
| for node_execution in workflow_node_executions: | for node_execution in workflow_node_executions: | ||||
| node_execution_id = node_execution.id | node_execution_id = node_execution.id | ||||
| tenant_id = node_execution.tenant_id | |||||
| app_id = node_execution.app_id | |||||
| tenant_id = trace_info.tenant_id # Use from trace_info instead | |||||
| app_id = trace_info.metadata.get("app_id") # Use from trace_info instead | |||||
| node_name = node_execution.title | node_name = node_execution.title | ||||
| node_type = node_execution.node_type | node_type = node_execution.node_type | ||||
| status = node_execution.status | status = node_execution.status | ||||
| if node_type == "llm": | |||||
| inputs = ( | |||||
| json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} | |||||
| ) | |||||
| if node_type == NodeType.LLM: | |||||
| inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} | |||||
| else: | else: | ||||
| inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} | |||||
| outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} | |||||
| inputs = node_execution.inputs if node_execution.inputs else {} | |||||
| outputs = node_execution.outputs if node_execution.outputs else {} | |||||
| created_at = node_execution.created_at or datetime.now() | created_at = node_execution.created_at or datetime.now() | ||||
| elapsed_time = node_execution.elapsed_time | elapsed_time = node_execution.elapsed_time | ||||
| finished_at = created_at + timedelta(seconds=elapsed_time) | finished_at = created_at + timedelta(seconds=elapsed_time) | ||||
| metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} | |||||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||||
| metadata = {str(k): v for k, v in execution_metadata.items()} | |||||
| metadata.update( | metadata.update( | ||||
| { | { | ||||
| "workflow_run_id": trace_info.workflow_run_id, | "workflow_run_id": trace_info.workflow_run_id, | ||||
| "status": status, | "status": status, | ||||
| } | } | ||||
| ) | ) | ||||
| process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} | |||||
| process_data = node_execution.process_data if node_execution.process_data else {} | |||||
| model_provider = process_data.get("model_provider", None) | model_provider = process_data.get("model_provider", None) | ||||
| model_name = process_data.get("model_name", None) | model_name = process_data.get("model_name", None) | ||||
| if model_provider is not None and model_name is not None: | if model_provider is not None and model_name is not None: |
| from collections.abc import Mapping | |||||
| from datetime import datetime | from datetime import datetime | ||||
| from enum import StrEnum | from enum import StrEnum | ||||
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): | class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): | ||||
| name: Optional[str] = Field(..., description="Name of the run") | name: Optional[str] = Field(..., description="Name of the run") | ||||
| inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") | |||||
| outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") | |||||
| inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") | |||||
| outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") | |||||
| run_type: LangSmithRunType = Field(..., description="Type of the run") | run_type: LangSmithRunType = Field(..., description="Type of the run") | ||||
| start_time: Optional[datetime | str] = Field(None, description="Start time of the run") | start_time: Optional[datetime | str] = Field(None, description="Start time of the run") | ||||
| end_time: Optional[datetime | str] = Field(None, description="End time of the run") | end_time: Optional[datetime | str] = Field(None, description="End time of the run") |
| import json | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| import uuid | import uuid | ||||
| from langsmith import Client | from langsmith import Client | ||||
| from langsmith.schemas import RunBase | from langsmith.schemas import RunBase | ||||
| from sqlalchemy.orm import sessionmaker | |||||
| from sqlalchemy.orm import Session, sessionmaker | |||||
| from core.ops.base_trace_instance import BaseTraceInstance | from core.ops.base_trace_instance import BaseTraceInstance | ||||
| from core.ops.entities.config_entity import LangSmithConfig | from core.ops.entities.config_entity import LangSmithConfig | ||||
| ) | ) | ||||
| from core.ops.utils import filter_none_values, generate_dotted_order | from core.ops.utils import filter_none_values, generate_dotted_order | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser, MessageFile | |||||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| # Find the app's creator account | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| # Get the app to find its creator | |||||
| app_id = trace_info.metadata.get("app_id") | |||||
| if not app_id: | |||||
| raise ValueError("No app_id found in trace_info metadata") | |||||
| app = session.query(App).filter(App.id == app_id).first() | |||||
| if not app: | |||||
| raise ValueError(f"App with id {app_id} not found") | |||||
| if not app.created_by: | |||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||||
| if not service_account: | |||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") | |||||
| session_factory=session_factory, | |||||
| user=service_account, | |||||
| app_id=trace_info.metadata.get("app_id"), | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run | ||||
| for node_execution in workflow_node_executions: | for node_execution in workflow_node_executions: | ||||
| node_execution_id = node_execution.id | node_execution_id = node_execution.id | ||||
| tenant_id = node_execution.tenant_id | |||||
| app_id = node_execution.app_id | |||||
| tenant_id = trace_info.tenant_id # Use from trace_info instead | |||||
| app_id = trace_info.metadata.get("app_id") # Use from trace_info instead | |||||
| node_name = node_execution.title | node_name = node_execution.title | ||||
| node_type = node_execution.node_type | node_type = node_execution.node_type | ||||
| status = node_execution.status | status = node_execution.status | ||||
| if node_type == "llm": | |||||
| inputs = ( | |||||
| json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} | |||||
| ) | |||||
| if node_type == NodeType.LLM: | |||||
| inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} | |||||
| else: | else: | ||||
| inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} | |||||
| outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} | |||||
| inputs = node_execution.inputs if node_execution.inputs else {} | |||||
| outputs = node_execution.outputs if node_execution.outputs else {} | |||||
| created_at = node_execution.created_at or datetime.now() | created_at = node_execution.created_at or datetime.now() | ||||
| elapsed_time = node_execution.elapsed_time | elapsed_time = node_execution.elapsed_time | ||||
| finished_at = created_at + timedelta(seconds=elapsed_time) | finished_at = created_at + timedelta(seconds=elapsed_time) | ||||
| execution_metadata = ( | |||||
| json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} | |||||
| ) | |||||
| node_total_tokens = execution_metadata.get("total_tokens", 0) | |||||
| metadata = execution_metadata.copy() | |||||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||||
| node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||||
| metadata = {str(key): value for key, value in execution_metadata.items()} | |||||
| metadata.update( | metadata.update( | ||||
| { | { | ||||
| "workflow_run_id": trace_info.workflow_run_id, | "workflow_run_id": trace_info.workflow_run_id, | ||||
| } | } | ||||
| ) | ) | ||||
| process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} | |||||
| process_data = node_execution.process_data if node_execution.process_data else {} | |||||
| if process_data and process_data.get("model_mode") == "chat": | if process_data and process_data.get("model_mode") == "chat": | ||||
| run_type = LangSmithRunType.llm | run_type = LangSmithRunType.llm | ||||
| "ls_model_name": process_data.get("model_name", ""), | "ls_model_name": process_data.get("model_name", ""), | ||||
| } | } | ||||
| ) | ) | ||||
| elif node_type == "knowledge-retrieval": | |||||
| elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: | |||||
| run_type = LangSmithRunType.retriever | run_type = LangSmithRunType.retriever | ||||
| else: | else: | ||||
| run_type = LangSmithRunType.tool | run_type = LangSmithRunType.tool |
| import json | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| import uuid | import uuid | ||||
| from opik import Opik, Trace | from opik import Opik, Trace | ||||
| from opik.id_helpers import uuid4_to_uuid7 | from opik.id_helpers import uuid4_to_uuid7 | ||||
| from sqlalchemy.orm import sessionmaker | |||||
| from sqlalchemy.orm import Session, sessionmaker | |||||
| from core.ops.base_trace_instance import BaseTraceInstance | from core.ops.base_trace_instance import BaseTraceInstance | ||||
| from core.ops.entities.config_entity import OpikConfig | from core.ops.entities.config_entity import OpikConfig | ||||
| WorkflowTraceInfo, | WorkflowTraceInfo, | ||||
| ) | ) | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser, MessageFile | |||||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| # Find the app's creator account | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| # Get the app to find its creator | |||||
| app_id = trace_info.metadata.get("app_id") | |||||
| if not app_id: | |||||
| raise ValueError("No app_id found in trace_info metadata") | |||||
| app = session.query(App).filter(App.id == app_id).first() | |||||
| if not app: | |||||
| raise ValueError(f"App with id {app_id} not found") | |||||
| if not app.created_by: | |||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||||
| if not service_account: | |||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") | |||||
| session_factory=session_factory, | |||||
| user=service_account, | |||||
| app_id=trace_info.metadata.get("app_id"), | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run | ||||
| for node_execution in workflow_node_executions: | for node_execution in workflow_node_executions: | ||||
| node_execution_id = node_execution.id | node_execution_id = node_execution.id | ||||
| tenant_id = node_execution.tenant_id | |||||
| app_id = node_execution.app_id | |||||
| tenant_id = trace_info.tenant_id # Use from trace_info instead | |||||
| app_id = trace_info.metadata.get("app_id") # Use from trace_info instead | |||||
| node_name = node_execution.title | node_name = node_execution.title | ||||
| node_type = node_execution.node_type | node_type = node_execution.node_type | ||||
| status = node_execution.status | status = node_execution.status | ||||
| if node_type == "llm": | |||||
| inputs = ( | |||||
| json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} | |||||
| ) | |||||
| if node_type == NodeType.LLM: | |||||
| inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} | |||||
| else: | else: | ||||
| inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} | |||||
| outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} | |||||
| inputs = node_execution.inputs if node_execution.inputs else {} | |||||
| outputs = node_execution.outputs if node_execution.outputs else {} | |||||
| created_at = node_execution.created_at or datetime.now() | created_at = node_execution.created_at or datetime.now() | ||||
| elapsed_time = node_execution.elapsed_time | elapsed_time = node_execution.elapsed_time | ||||
| finished_at = created_at + timedelta(seconds=elapsed_time) | finished_at = created_at + timedelta(seconds=elapsed_time) | ||||
| execution_metadata = ( | |||||
| json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} | |||||
| ) | |||||
| metadata = execution_metadata.copy() | |||||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||||
| metadata = {str(k): v for k, v in execution_metadata.items()} | |||||
| metadata.update( | metadata.update( | ||||
| { | { | ||||
| "workflow_run_id": trace_info.workflow_run_id, | "workflow_run_id": trace_info.workflow_run_id, | ||||
| } | } | ||||
| ) | ) | ||||
| process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} | |||||
| process_data = node_execution.process_data if node_execution.process_data else {} | |||||
| provider = None | provider = None | ||||
| model = None | model = None | ||||
| parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id | parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id | ||||
| if not total_tokens: | if not total_tokens: | ||||
| total_tokens = execution_metadata.get("total_tokens", 0) | |||||
| total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||||
| span_data = { | span_data = { | ||||
| "trace_id": opik_trace_id, | "trace_id": opik_trace_id, |
| from collections.abc import Mapping | |||||
| from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
| from pydantic import BaseModel, Field, field_validator | from pydantic import BaseModel, Field, field_validator | ||||
| class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): | class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): | ||||
| id: str = Field(..., description="ID of the trace") | id: str = Field(..., description="ID of the trace") | ||||
| op: str = Field(..., description="Name of the operation") | op: str = Field(..., description="Name of the operation") | ||||
| inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") | |||||
| outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") | |||||
| inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") | |||||
| outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") | |||||
| attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( | attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( | ||||
| None, description="Metadata and attributes associated with trace" | None, description="Metadata and attributes associated with trace" | ||||
| ) | ) |
| import json | |||||
| import logging | import logging | ||||
| import os | import os | ||||
| import uuid | import uuid | ||||
| import wandb | import wandb | ||||
| import weave | import weave | ||||
| from sqlalchemy.orm import Session, sessionmaker | |||||
| from core.ops.base_trace_instance import BaseTraceInstance | from core.ops.base_trace_instance import BaseTraceInstance | ||||
| from core.ops.entities.config_entity import WeaveConfig | from core.ops.entities.config_entity import WeaveConfig | ||||
| WorkflowTraceInfo, | WorkflowTraceInfo, | ||||
| ) | ) | ||||
| from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel | from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser, MessageFile | |||||
| from models.workflow import WorkflowNodeExecution | |||||
| from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| self.start_call(workflow_run, parent_run_id=trace_info.message_id) | self.start_call(workflow_run, parent_run_id=trace_info.message_id) | ||||
| # through workflow_run_id get all_nodes_execution | |||||
| workflow_nodes_execution_id_records = ( | |||||
| db.session.query(WorkflowNodeExecution.id) | |||||
| .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) | |||||
| .all() | |||||
| # through workflow_run_id get all_nodes_execution using repository | |||||
| session_factory = sessionmaker(bind=db.engine) | |||||
| # Find the app's creator account | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| # Get the app to find its creator | |||||
| app_id = trace_info.metadata.get("app_id") | |||||
| if not app_id: | |||||
| raise ValueError("No app_id found in trace_info metadata") | |||||
| app = session.query(App).filter(App.id == app_id).first() | |||||
| if not app: | |||||
| raise ValueError(f"App with id {app_id} not found") | |||||
| if not app.created_by: | |||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||||
| if not service_account: | |||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| user=service_account, | |||||
| app_id=trace_info.metadata.get("app_id"), | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| for node_execution_id_record in workflow_nodes_execution_id_records: | |||||
| node_execution = ( | |||||
| db.session.query( | |||||
| WorkflowNodeExecution.id, | |||||
| WorkflowNodeExecution.tenant_id, | |||||
| WorkflowNodeExecution.app_id, | |||||
| WorkflowNodeExecution.title, | |||||
| WorkflowNodeExecution.node_type, | |||||
| WorkflowNodeExecution.status, | |||||
| WorkflowNodeExecution.inputs, | |||||
| WorkflowNodeExecution.outputs, | |||||
| WorkflowNodeExecution.created_at, | |||||
| WorkflowNodeExecution.elapsed_time, | |||||
| WorkflowNodeExecution.process_data, | |||||
| WorkflowNodeExecution.execution_metadata, | |||||
| ) | |||||
| .filter(WorkflowNodeExecution.id == node_execution_id_record.id) | |||||
| .first() | |||||
| ) | |||||
| if not node_execution: | |||||
| continue | |||||
| # Get all executions for this workflow run | |||||
| workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( | |||||
| workflow_run_id=trace_info.workflow_run_id | |||||
| ) | |||||
| for node_execution in workflow_node_executions: | |||||
| node_execution_id = node_execution.id | node_execution_id = node_execution.id | ||||
| tenant_id = node_execution.tenant_id | |||||
| app_id = node_execution.app_id | |||||
| tenant_id = trace_info.tenant_id # Use from trace_info instead | |||||
| app_id = trace_info.metadata.get("app_id") # Use from trace_info instead | |||||
| node_name = node_execution.title | node_name = node_execution.title | ||||
| node_type = node_execution.node_type | node_type = node_execution.node_type | ||||
| status = node_execution.status | status = node_execution.status | ||||
| if node_type == "llm": | |||||
| inputs = ( | |||||
| json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} | |||||
| ) | |||||
| if node_type == NodeType.LLM: | |||||
| inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} | |||||
| else: | else: | ||||
| inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} | |||||
| outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} | |||||
| inputs = node_execution.inputs if node_execution.inputs else {} | |||||
| outputs = node_execution.outputs if node_execution.outputs else {} | |||||
| created_at = node_execution.created_at or datetime.now() | created_at = node_execution.created_at or datetime.now() | ||||
| elapsed_time = node_execution.elapsed_time | elapsed_time = node_execution.elapsed_time | ||||
| finished_at = created_at + timedelta(seconds=elapsed_time) | finished_at = created_at + timedelta(seconds=elapsed_time) | ||||
| execution_metadata = ( | |||||
| json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} | |||||
| ) | |||||
| node_total_tokens = execution_metadata.get("total_tokens", 0) | |||||
| attributes = execution_metadata.copy() | |||||
| execution_metadata = node_execution.metadata if node_execution.metadata else {} | |||||
| node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 | |||||
| attributes = {str(k): v for k, v in execution_metadata.items()} | |||||
| attributes.update( | attributes.update( | ||||
| { | { | ||||
| "workflow_run_id": trace_info.workflow_run_id, | "workflow_run_id": trace_info.workflow_run_id, | ||||
| } | } | ||||
| ) | ) | ||||
| process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} | |||||
| process_data = node_execution.process_data if node_execution.process_data else {} | |||||
| if process_data and process_data.get("model_mode") == "chat": | if process_data and process_data.get("model_mode") == "chat": | ||||
| attributes.update( | attributes.update( | ||||
| { | { |
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import UploadFile | from models.model import UploadFile | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| extension=str(image_ext), | extension=str(image_ext), | ||||
| mime_type=mime_type or "", | mime_type=mime_type or "", | ||||
| created_by=self.user_id, | created_by=self.user_id, | ||||
| created_by_role=CreatedByRole.ACCOUNT, | |||||
| created_by_role=CreatorUserRole.ACCOUNT, | |||||
| created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | ||||
| used=True, | used=True, | ||||
| used_by=self.user_id, | used_by=self.user_id, |
| SQLAlchemy implementation of the WorkflowNodeExecutionRepository. | SQLAlchemy implementation of the WorkflowNodeExecutionRepository. | ||||
| """ | """ | ||||
| import json | |||||
| import logging | import logging | ||||
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| from typing import Optional | |||||
| from typing import Optional, Union | |||||
| from sqlalchemy import UnaryExpression, asc, delete, desc, select | from sqlalchemy import UnaryExpression, asc, delete, desc, select | ||||
| from sqlalchemy.engine import Engine | from sqlalchemy.engine import Engine | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| from core.workflow.entities.node_execution_entities import ( | |||||
| NodeExecution, | |||||
| NodeExecutionStatus, | |||||
| ) | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | ||||
| from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom | |||||
| from models import ( | |||||
| Account, | |||||
| CreatorUserRole, | |||||
| EndUser, | |||||
| WorkflowNodeExecution, | |||||
| WorkflowNodeExecutionStatus, | |||||
| WorkflowNodeExecutionTriggeredFrom, | |||||
| ) | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| This implementation supports multi-tenancy by filtering operations based on tenant_id. | This implementation supports multi-tenancy by filtering operations based on tenant_id. | ||||
| Each method creates its own session, handles the transaction, and commits changes | Each method creates its own session, handles the transaction, and commits changes | ||||
| to the database. This prevents long-running connections in the workflow core. | to the database. This prevents long-running connections in the workflow core. | ||||
| This implementation also includes an in-memory cache for node executions to improve | |||||
| performance by reducing database queries. | |||||
| """ | """ | ||||
| def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): | |||||
| def __init__( | |||||
| self, | |||||
| session_factory: sessionmaker | Engine, | |||||
| user: Union[Account, EndUser], | |||||
| app_id: Optional[str], | |||||
| triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], | |||||
| ): | |||||
| """ | """ | ||||
| Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. | |||||
| Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. | |||||
| Args: | Args: | ||||
| session_factory: SQLAlchemy sessionmaker or engine for creating sessions | session_factory: SQLAlchemy sessionmaker or engine for creating sessions | ||||
| tenant_id: Tenant ID for multi-tenancy | |||||
| app_id: Optional app ID for filtering by application | |||||
| user: Account or EndUser object containing tenant_id, user ID, and role information | |||||
| app_id: App ID for filtering by application (can be None) | |||||
| triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) | |||||
| """ | """ | ||||
| # If an engine is provided, create a sessionmaker from it | # If an engine is provided, create a sessionmaker from it | ||||
| if isinstance(session_factory, Engine): | if isinstance(session_factory, Engine): | ||||
| f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" | f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" | ||||
| ) | ) | ||||
| # Extract tenant_id from user | |||||
| tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id | |||||
| if not tenant_id: | |||||
| raise ValueError("User must have a tenant_id or current_tenant_id") | |||||
| self._tenant_id = tenant_id | self._tenant_id = tenant_id | ||||
| # Store app context | |||||
| self._app_id = app_id | self._app_id = app_id | ||||
| def save(self, execution: WorkflowNodeExecution) -> None: | |||||
| # Extract user context | |||||
| self._triggered_from = triggered_from | |||||
| self._creator_user_id = user.id | |||||
| # Determine user role based on user type | |||||
| self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER | |||||
| # Initialize in-memory cache for node executions | |||||
| # Key: node_execution_id, Value: NodeExecution | |||||
| self._node_execution_cache: dict[str, NodeExecution] = {} | |||||
| def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution: | |||||
| """ | """ | ||||
| Save a WorkflowNodeExecution instance and commit changes to the database. | |||||
| Convert a database model to a domain model. | |||||
| Args: | Args: | ||||
| execution: The WorkflowNodeExecution instance to save | |||||
| db_model: The database model to convert | |||||
| Returns: | |||||
| The domain model | |||||
| """ | """ | ||||
| with self._session_factory() as session: | |||||
| # Ensure tenant_id is set | |||||
| if not execution.tenant_id: | |||||
| execution.tenant_id = self._tenant_id | |||||
| # Parse JSON fields | |||||
| inputs = db_model.inputs_dict | |||||
| process_data = db_model.process_data_dict | |||||
| outputs = db_model.outputs_dict | |||||
| metadata = db_model.execution_metadata_dict | |||||
| # Convert status to domain enum | |||||
| status = NodeExecutionStatus(db_model.status) | |||||
| return NodeExecution( | |||||
| id=db_model.id, | |||||
| node_execution_id=db_model.node_execution_id, | |||||
| workflow_id=db_model.workflow_id, | |||||
| workflow_run_id=db_model.workflow_run_id, | |||||
| index=db_model.index, | |||||
| predecessor_node_id=db_model.predecessor_node_id, | |||||
| node_id=db_model.node_id, | |||||
| node_type=NodeType(db_model.node_type), | |||||
| title=db_model.title, | |||||
| inputs=inputs, | |||||
| process_data=process_data, | |||||
| outputs=outputs, | |||||
| status=status, | |||||
| error=db_model.error, | |||||
| elapsed_time=db_model.elapsed_time, | |||||
| metadata=metadata, | |||||
| created_at=db_model.created_at, | |||||
| finished_at=db_model.finished_at, | |||||
| ) | |||||
| def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution: | |||||
| """ | |||||
| Convert a domain model to a database model. | |||||
| Args: | |||||
| domain_model: The domain model to convert | |||||
| Returns: | |||||
| The database model | |||||
| """ | |||||
| # Use values from constructor if provided | |||||
| if not self._triggered_from: | |||||
| raise ValueError("triggered_from is required in repository constructor") | |||||
| if not self._creator_user_id: | |||||
| raise ValueError("created_by is required in repository constructor") | |||||
| if not self._creator_user_role: | |||||
| raise ValueError("created_by_role is required in repository constructor") | |||||
| db_model = WorkflowNodeExecution() | |||||
| db_model.id = domain_model.id | |||||
| db_model.tenant_id = self._tenant_id | |||||
| if self._app_id is not None: | |||||
| db_model.app_id = self._app_id | |||||
| db_model.workflow_id = domain_model.workflow_id | |||||
| db_model.triggered_from = self._triggered_from | |||||
| db_model.workflow_run_id = domain_model.workflow_run_id | |||||
| db_model.index = domain_model.index | |||||
| db_model.predecessor_node_id = domain_model.predecessor_node_id | |||||
| db_model.node_execution_id = domain_model.node_execution_id | |||||
| db_model.node_id = domain_model.node_id | |||||
| db_model.node_type = domain_model.node_type | |||||
| db_model.title = domain_model.title | |||||
| db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None | |||||
| db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None | |||||
| db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None | |||||
| db_model.status = domain_model.status | |||||
| db_model.error = domain_model.error | |||||
| db_model.elapsed_time = domain_model.elapsed_time | |||||
| db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None | |||||
| db_model.created_at = domain_model.created_at | |||||
| db_model.created_by_role = self._creator_user_role | |||||
| db_model.created_by = self._creator_user_id | |||||
| db_model.finished_at = domain_model.finished_at | |||||
| return db_model | |||||
| def save(self, execution: NodeExecution) -> None: | |||||
| """ | |||||
| Save or update a NodeExecution instance and commit changes to the database. | |||||
| # Set app_id if provided and not already set | |||||
| if self._app_id and not execution.app_id: | |||||
| execution.app_id = self._app_id | |||||
| This method handles both creating new records and updating existing ones. | |||||
| It determines whether to create or update based on whether the record | |||||
| already exists in the database. It also updates the in-memory cache. | |||||
| session.add(execution) | |||||
| Args: | |||||
| execution: The NodeExecution instance to save or update | |||||
| """ | |||||
| with self._session_factory() as session: | |||||
| # Convert domain model to database model using instance attributes | |||||
| db_model = self._to_db_model(execution) | |||||
| # Use merge which will handle both insert and update | |||||
| session.merge(db_model) | |||||
| session.commit() | session.commit() | ||||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: | |||||
| # Update the cache if node_execution_id is present | |||||
| if execution.node_execution_id: | |||||
| logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}") | |||||
| self._node_execution_cache[execution.node_execution_id] = execution | |||||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: | |||||
| """ | """ | ||||
| Retrieve a WorkflowNodeExecution by its node_execution_id. | |||||
| Retrieve a NodeExecution by its node_execution_id. | |||||
| First checks the in-memory cache, and if not found, queries the database. | |||||
| If found in the database, adds it to the cache for future lookups. | |||||
| Args: | Args: | ||||
| node_execution_id: The node execution ID | node_execution_id: The node execution ID | ||||
| Returns: | Returns: | ||||
| The WorkflowNodeExecution instance if found, None otherwise | |||||
| The NodeExecution instance if found, None otherwise | |||||
| """ | """ | ||||
| # First check the cache | |||||
| if node_execution_id in self._node_execution_cache: | |||||
| logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") | |||||
| return self._node_execution_cache[node_execution_id] | |||||
| # If not in cache, query the database | |||||
| logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") | |||||
| with self._session_factory() as session: | with self._session_factory() as session: | ||||
| stmt = select(WorkflowNodeExecution).where( | stmt = select(WorkflowNodeExecution).where( | ||||
| WorkflowNodeExecution.node_execution_id == node_execution_id, | WorkflowNodeExecution.node_execution_id == node_execution_id, | ||||
| if self._app_id: | if self._app_id: | ||||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | ||||
| return session.scalar(stmt) | |||||
| db_model = session.scalar(stmt) | |||||
| if db_model: | |||||
| # Convert to domain model | |||||
| domain_model = self._to_domain_model(db_model) | |||||
| # Add to cache | |||||
| self._node_execution_cache[node_execution_id] = domain_model | |||||
| return domain_model | |||||
| return None | |||||
| def get_by_workflow_run( | def get_by_workflow_run( | ||||
| self, | self, | ||||
| workflow_run_id: str, | workflow_run_id: str, | ||||
| order_config: Optional[OrderConfig] = None, | order_config: Optional[OrderConfig] = None, | ||||
| ) -> Sequence[NodeExecution]: | |||||
| """ | |||||
| Retrieve all NodeExecution instances for a specific workflow run. | |||||
| This method always queries the database to ensure complete and ordered results, | |||||
| but updates the cache with any retrieved executions. | |||||
| Args: | |||||
| workflow_run_id: The workflow run ID | |||||
| order_config: Optional configuration for ordering results | |||||
| order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) | |||||
| order_config.order_direction: Direction to order ("asc" or "desc") | |||||
| Returns: | |||||
| A list of NodeExecution instances | |||||
| """ | |||||
| # Get the raw database models using the new method | |||||
| db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) | |||||
| # Convert database models to domain models and update cache | |||||
| domain_models = [] | |||||
| for model in db_models: | |||||
| domain_model = self._to_domain_model(model) | |||||
| # Update cache if node_execution_id is present | |||||
| if domain_model.node_execution_id: | |||||
| self._node_execution_cache[domain_model.node_execution_id] = domain_model | |||||
| domain_models.append(domain_model) | |||||
| return domain_models | |||||
| def get_db_models_by_workflow_run( | |||||
| self, | |||||
| workflow_run_id: str, | |||||
| order_config: Optional[OrderConfig] = None, | |||||
| ) -> Sequence[WorkflowNodeExecution]: | ) -> Sequence[WorkflowNodeExecution]: | ||||
| """ | """ | ||||
| Retrieve all WorkflowNodeExecution instances for a specific workflow run. | |||||
| Retrieve all WorkflowNodeExecution database models for a specific workflow run. | |||||
| This method is similar to get_by_workflow_run but returns the raw database models | |||||
| instead of converting them to domain models. This can be useful when direct access | |||||
| to database model properties is needed. | |||||
| Args: | Args: | ||||
| workflow_run_id: The workflow run ID | workflow_run_id: The workflow run ID | ||||
| order_config.order_direction: Direction to order ("asc" or "desc") | order_config.order_direction: Direction to order ("asc" or "desc") | ||||
| Returns: | Returns: | ||||
| A list of WorkflowNodeExecution instances | |||||
| A list of WorkflowNodeExecution database models | |||||
| """ | """ | ||||
| with self._session_factory() as session: | with self._session_factory() as session: | ||||
| stmt = select(WorkflowNodeExecution).where( | stmt = select(WorkflowNodeExecution).where( | ||||
| if order_columns: | if order_columns: | ||||
| stmt = stmt.order_by(*order_columns) | stmt = stmt.order_by(*order_columns) | ||||
| return session.scalars(stmt).all() | |||||
| db_models = session.scalars(stmt).all() | |||||
| # Note: We don't update the cache here since we're returning raw DB models | |||||
| # and not converting to domain models | |||||
| return db_models | |||||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: | |||||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: | |||||
| """ | """ | ||||
| Retrieve all running WorkflowNodeExecution instances for a specific workflow run. | |||||
| Retrieve all running NodeExecution instances for a specific workflow run. | |||||
| This method queries the database directly and updates the cache with any | |||||
| retrieved executions that have a node_execution_id. | |||||
| Args: | Args: | ||||
| workflow_run_id: The workflow run ID | workflow_run_id: The workflow run ID | ||||
| Returns: | Returns: | ||||
| A list of running WorkflowNodeExecution instances | |||||
| A list of running NodeExecution instances | |||||
| """ | """ | ||||
| with self._session_factory() as session: | with self._session_factory() as session: | ||||
| stmt = select(WorkflowNodeExecution).where( | stmt = select(WorkflowNodeExecution).where( | ||||
| if self._app_id: | if self._app_id: | ||||
| stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) | ||||
| return session.scalars(stmt).all() | |||||
| db_models = session.scalars(stmt).all() | |||||
| domain_models = [] | |||||
| def update(self, execution: WorkflowNodeExecution) -> None: | |||||
| """ | |||||
| Update an existing WorkflowNodeExecution instance and commit changes to the database. | |||||
| for model in db_models: | |||||
| domain_model = self._to_domain_model(model) | |||||
| # Update cache if node_execution_id is present | |||||
| if domain_model.node_execution_id: | |||||
| self._node_execution_cache[domain_model.node_execution_id] = domain_model | |||||
| domain_models.append(domain_model) | |||||
| Args: | |||||
| execution: The WorkflowNodeExecution instance to update | |||||
| """ | |||||
| with self._session_factory() as session: | |||||
| # Ensure tenant_id is set | |||||
| if not execution.tenant_id: | |||||
| execution.tenant_id = self._tenant_id | |||||
| # Set app_id if provided and not already set | |||||
| if self._app_id and not execution.app_id: | |||||
| execution.app_id = self._app_id | |||||
| session.merge(execution) | |||||
| session.commit() | |||||
| return domain_models | |||||
| def clear(self) -> None: | def clear(self) -> None: | ||||
| """ | """ | ||||
| This method deletes all WorkflowNodeExecution records that match the tenant_id | This method deletes all WorkflowNodeExecution records that match the tenant_id | ||||
| and app_id (if provided) associated with this repository instance. | and app_id (if provided) associated with this repository instance. | ||||
| It also clears the in-memory cache. | |||||
| """ | """ | ||||
| with self._session_factory() as session: | with self._session_factory() as session: | ||||
| stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) | stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) | ||||
| f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" | f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" | ||||
| + (f" and app {self._app_id}" if self._app_id else "") | + (f" and app {self._app_id}" if self._app_id else "") | ||||
| ) | ) | ||||
| # Clear the in-memory cache | |||||
| self._node_execution_cache.clear() | |||||
| logger.info("Cleared in-memory node execution cache") |
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | from core.tools.utils.message_transformer import ToolFileMessageTransformer | ||||
| from core.tools.workflow_as_tool.tool import WorkflowTool | from core.tools.workflow_as_tool.tool import WorkflowTool | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import Message, MessageFile | from models.model import Message, MessageFile | ||||
| url=message.url, | url=message.url, | ||||
| upload_file_id=tool_file_id, | upload_file_id=tool_file_id, | ||||
| created_by_role=( | created_by_role=( | ||||
| CreatedByRole.ACCOUNT | |||||
| CreatorUserRole.ACCOUNT | |||||
| if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | ||||
| else CreatedByRole.END_USER | |||||
| else CreatorUserRole.END_USER | |||||
| ), | ), | ||||
| created_by=user_id, | created_by=user_id, | ||||
| ) | ) |
| """ | |||||
| Domain entities for workflow node execution. | |||||
| This module contains the domain model for workflow node execution, which is used | |||||
| by the core workflow module. These models are independent of the storage mechanism | |||||
| and don't contain implementation details like tenant_id, app_id, etc. | |||||
| """ | |||||
| from collections.abc import Mapping | |||||
| from datetime import datetime | |||||
| from enum import StrEnum | |||||
| from typing import Any, Optional | |||||
| from pydantic import BaseModel, Field | |||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| class NodeExecutionStatus(StrEnum): | |||||
| """ | |||||
| Node Execution Status Enum. | |||||
| """ | |||||
| RUNNING = "running" | |||||
| SUCCEEDED = "succeeded" | |||||
| FAILED = "failed" | |||||
| EXCEPTION = "exception" | |||||
| RETRY = "retry" | |||||
| class NodeExecution(BaseModel): | |||||
| """ | |||||
| Domain model for workflow node execution. | |||||
| This model represents the core business entity of a node execution, | |||||
| without implementation details like tenant_id, app_id, etc. | |||||
| Note: User/context-specific fields (triggered_from, created_by, created_by_role) | |||||
| have been moved to the repository implementation to keep the domain model clean. | |||||
| These fields are still accepted in the constructor for backward compatibility, | |||||
| but they are not stored in the model. | |||||
| """ | |||||
| # Core identification fields | |||||
| id: str # Unique identifier for this execution record | |||||
| node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing | |||||
| workflow_id: str # ID of the workflow this node belongs to | |||||
| workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) | |||||
| # Execution positioning and flow | |||||
| index: int # Sequence number for ordering in trace visualization | |||||
| predecessor_node_id: Optional[str] = None # ID of the node that executed before this one | |||||
| node_id: str # ID of the node being executed | |||||
| node_type: NodeType # Type of node (e.g., start, llm, knowledge) | |||||
| title: str # Display title of the node | |||||
| # Execution data | |||||
| inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node | |||||
| process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data | |||||
| outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node | |||||
| # Execution state | |||||
| status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status | |||||
| error: Optional[str] = None # Error message if execution failed | |||||
| elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds | |||||
| # Additional metadata | |||||
| metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) | |||||
| # Timing information | |||||
| created_at: datetime # When execution started | |||||
| finished_at: Optional[datetime] = None # When execution completed | |||||
| def update_from_mapping( | |||||
| self, | |||||
| inputs: Optional[Mapping[str, Any]] = None, | |||||
| process_data: Optional[Mapping[str, Any]] = None, | |||||
| outputs: Optional[Mapping[str, Any]] = None, | |||||
| metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None, | |||||
| ) -> None: | |||||
| """ | |||||
| Update the model from mappings. | |||||
| Args: | |||||
| inputs: The inputs to update | |||||
| process_data: The process data to update | |||||
| outputs: The outputs to update | |||||
| metadata: The metadata to update | |||||
| """ | |||||
| if inputs is not None: | |||||
| self.inputs = dict(inputs) | |||||
| if process_data is not None: | |||||
| self.process_data = dict(process_data) | |||||
| if outputs is not None: | |||||
| self.outputs = dict(outputs) | |||||
| if metadata is not None: | |||||
| self.metadata = dict(metadata) |
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from typing import Literal, Optional, Protocol | from typing import Literal, Optional, Protocol | ||||
| from models.workflow import WorkflowNodeExecution | |||||
| from core.workflow.entities.node_execution_entities import NodeExecution | |||||
| @dataclass | @dataclass | ||||
| class OrderConfig: | class OrderConfig: | ||||
| """Configuration for ordering WorkflowNodeExecution instances.""" | |||||
| """Configuration for ordering NodeExecution instances.""" | |||||
| order_by: list[str] | order_by: list[str] | ||||
| order_direction: Optional[Literal["asc", "desc"]] = None | order_direction: Optional[Literal["asc", "desc"]] = None | ||||
| class WorkflowNodeExecutionRepository(Protocol): | class WorkflowNodeExecutionRepository(Protocol): | ||||
| """ | """ | ||||
| Repository interface for WorkflowNodeExecution. | |||||
| Repository interface for NodeExecution. | |||||
| This interface defines the contract for accessing and manipulating | This interface defines the contract for accessing and manipulating | ||||
| WorkflowNodeExecution data, regardless of the underlying storage mechanism. | |||||
| NodeExecution data, regardless of the underlying storage mechanism. | |||||
| Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), | Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), | ||||
| and trigger sources (triggered_from) should be handled at the implementation level, not in | and trigger sources (triggered_from) should be handled at the implementation level, not in | ||||
| application domains or deployment scenarios. | application domains or deployment scenarios. | ||||
| """ | """ | ||||
| def save(self, execution: WorkflowNodeExecution) -> None: | |||||
| def save(self, execution: NodeExecution) -> None: | |||||
| """ | """ | ||||
| Save a WorkflowNodeExecution instance. | |||||
| Save or update a NodeExecution instance. | |||||
| This method handles both creating new records and updating existing ones. | |||||
| The implementation should determine whether to create or update based on | |||||
| the execution's ID or other identifying fields. | |||||
| Args: | Args: | ||||
| execution: The WorkflowNodeExecution instance to save | |||||
| execution: The NodeExecution instance to save or update | |||||
| """ | """ | ||||
| ... | ... | ||||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: | |||||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: | |||||
| """ | """ | ||||
| Retrieve a WorkflowNodeExecution by its node_execution_id. | |||||
| Retrieve a NodeExecution by its node_execution_id. | |||||
| Args: | Args: | ||||
| node_execution_id: The node execution ID | node_execution_id: The node execution ID | ||||
| Returns: | Returns: | ||||
| The WorkflowNodeExecution instance if found, None otherwise | |||||
| The NodeExecution instance if found, None otherwise | |||||
| """ | """ | ||||
| ... | ... | ||||
| self, | self, | ||||
| workflow_run_id: str, | workflow_run_id: str, | ||||
| order_config: Optional[OrderConfig] = None, | order_config: Optional[OrderConfig] = None, | ||||
| ) -> Sequence[WorkflowNodeExecution]: | |||||
| ) -> Sequence[NodeExecution]: | |||||
| """ | """ | ||||
| Retrieve all WorkflowNodeExecution instances for a specific workflow run. | |||||
| Retrieve all NodeExecution instances for a specific workflow run. | |||||
| Args: | Args: | ||||
| workflow_run_id: The workflow run ID | workflow_run_id: The workflow run ID | ||||
| order_config.order_direction: Direction to order ("asc" or "desc") | order_config.order_direction: Direction to order ("asc" or "desc") | ||||
| Returns: | Returns: | ||||
| A list of WorkflowNodeExecution instances | |||||
| A list of NodeExecution instances | |||||
| """ | """ | ||||
| ... | ... | ||||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: | |||||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: | |||||
| """ | """ | ||||
| Retrieve all running WorkflowNodeExecution instances for a specific workflow run. | |||||
| Retrieve all running NodeExecution instances for a specific workflow run. | |||||
| Args: | Args: | ||||
| workflow_run_id: The workflow run ID | workflow_run_id: The workflow run ID | ||||
| Returns: | Returns: | ||||
| A list of running WorkflowNodeExecution instances | |||||
| """ | |||||
| ... | |||||
| def update(self, execution: WorkflowNodeExecution) -> None: | |||||
| """ | |||||
| Update an existing WorkflowNodeExecution instance. | |||||
| Args: | |||||
| execution: The WorkflowNodeExecution instance to update | |||||
| A list of running NodeExecution instances | |||||
| """ | """ | ||||
| ... | ... | ||||
| def clear(self) -> None: | def clear(self) -> None: | ||||
| """ | """ | ||||
| Clear all WorkflowNodeExecution records based on implementation-specific criteria. | |||||
| Clear all NodeExecution records based on implementation-specific criteria. | |||||
| This method is intended to be used for bulk deletion operations, such as removing | This method is intended to be used for bulk deletion operations, such as removing | ||||
| all records associated with a specific app_id and tenant_id in multi-tenant implementations. | all records associated with a specific app_id and tenant_id in multi-tenant implementations. |
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | from core.workflow.workflow_cycle_manager import WorkflowCycleManager | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import EndUser | from models.model import EndUser | ||||
| from models.workflow import ( | from models.workflow import ( | ||||
| Workflow, | Workflow, | ||||
| if isinstance(user, EndUser): | if isinstance(user, EndUser): | ||||
| self._user_id = user.id | self._user_id = user.id | ||||
| user_session_id = user.session_id | user_session_id = user.session_id | ||||
| self._created_by_role = CreatedByRole.END_USER | |||||
| self._created_by_role = CreatorUserRole.END_USER | |||||
| elif isinstance(user, Account): | elif isinstance(user, Account): | ||||
| self._user_id = user.id | self._user_id = user.id | ||||
| user_session_id = user.id | user_session_id = user.id | ||||
| self._created_by_role = CreatedByRole.ACCOUNT | |||||
| self._created_by_role = CreatorUserRole.ACCOUNT | |||||
| else: | else: | ||||
| raise ValueError(f"Invalid user type: {type(user)}") | raise ValueError(f"Invalid user type: {type(user)}") | ||||
| ) | ) | ||||
| from core.app.task_pipeline.exc import WorkflowRunNotFoundError | from core.app.task_pipeline.exc import WorkflowRunNotFoundError | ||||
| from core.file import FILE_MODEL_IDENTITY, File | from core.file import FILE_MODEL_IDENTITY, File | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.ops.entities.trace_entity import TraceTaskName | from core.ops.entities.trace_entity import TraceTaskName | ||||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | ||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | from core.workflow.entities.node_entities import NodeRunMetadataKey | ||||
| from core.workflow.entities.node_execution_entities import ( | |||||
| NodeExecution, | |||||
| NodeExecutionStatus, | |||||
| ) | |||||
| from core.workflow.enums import SystemVariableKey | from core.workflow.enums import SystemVariableKey | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| from core.workflow.nodes.tool.entities import ToolNodeData | from core.workflow.nodes.tool.entities import ToolNodeData | ||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from core.workflow.workflow_entry import WorkflowEntry | from core.workflow.workflow_entry import WorkflowEntry | ||||
| from models.account import Account | |||||
| from models.enums import CreatedByRole, WorkflowRunTriggeredFrom | |||||
| from models.model import EndUser | |||||
| from models.workflow import ( | |||||
| from models import ( | |||||
| Account, | |||||
| CreatorUserRole, | |||||
| EndUser, | |||||
| Workflow, | Workflow, | ||||
| WorkflowNodeExecution, | |||||
| WorkflowNodeExecutionStatus, | WorkflowNodeExecutionStatus, | ||||
| WorkflowNodeExecutionTriggeredFrom, | |||||
| WorkflowRun, | WorkflowRun, | ||||
| WorkflowRunStatus, | WorkflowRunStatus, | ||||
| WorkflowRunTriggeredFrom, | |||||
| ) | ) | ||||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | workflow_node_execution_repository: WorkflowNodeExecutionRepository, | ||||
| ) -> None: | ) -> None: | ||||
| self._workflow_run: WorkflowRun | None = None | self._workflow_run: WorkflowRun | None = None | ||||
| self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} | |||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._workflow_system_variables = workflow_system_variables | self._workflow_system_variables = workflow_system_variables | ||||
| self._workflow_node_execution_repository = workflow_node_execution_repository | self._workflow_node_execution_repository = workflow_node_execution_repository | ||||
| session: Session, | session: Session, | ||||
| workflow_id: str, | workflow_id: str, | ||||
| user_id: str, | user_id: str, | ||||
| created_by_role: CreatedByRole, | |||||
| created_by_role: CreatorUserRole, | |||||
| ) -> WorkflowRun: | ) -> WorkflowRun: | ||||
| workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | ||||
| workflow = session.scalar(workflow_stmt) | workflow = session.scalar(workflow_stmt) | ||||
| workflow_run.exceptions_count = exceptions_count | workflow_run.exceptions_count = exceptions_count | ||||
| # Use the instance repository to find running executions for a workflow run | # Use the instance repository to find running executions for a workflow run | ||||
| running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( | |||||
| running_domain_executions = self._workflow_node_execution_repository.get_running_executions( | |||||
| workflow_run_id=workflow_run.id | workflow_run_id=workflow_run.id | ||||
| ) | ) | ||||
| # Update the cache with the retrieved executions | |||||
| for execution in running_workflow_node_executions: | |||||
| if execution.node_execution_id: | |||||
| self._workflow_node_executions[execution.node_execution_id] = execution | |||||
| # Update the domain models | |||||
| now = datetime.now(UTC).replace(tzinfo=None) | |||||
| for domain_execution in running_domain_executions: | |||||
| if domain_execution.node_execution_id: | |||||
| # Update the domain model | |||||
| domain_execution.status = NodeExecutionStatus.FAILED | |||||
| domain_execution.error = error | |||||
| domain_execution.finished_at = now | |||||
| domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() | |||||
| for workflow_node_execution in running_workflow_node_executions: | |||||
| now = datetime.now(UTC).replace(tzinfo=None) | |||||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||||
| workflow_node_execution.error = error | |||||
| workflow_node_execution.finished_at = now | |||||
| workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() | |||||
| # Update the repository with the domain model | |||||
| self._workflow_node_execution_repository.save(domain_execution) | |||||
| if trace_manager: | if trace_manager: | ||||
| trace_manager.add_trace_task( | trace_manager.add_trace_task( | ||||
| return workflow_run | return workflow_run | ||||
| def _handle_node_execution_start( | |||||
| self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent | |||||
| ) -> WorkflowNodeExecution: | |||||
| workflow_node_execution = WorkflowNodeExecution() | |||||
| workflow_node_execution.id = str(uuid4()) | |||||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||||
| workflow_node_execution.app_id = workflow_run.app_id | |||||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||||
| workflow_node_execution.predecessor_node_id = event.predecessor_node_id | |||||
| workflow_node_execution.index = event.node_run_index | |||||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||||
| workflow_node_execution.node_id = event.node_id | |||||
| workflow_node_execution.node_type = event.node_type.value | |||||
| workflow_node_execution.title = event.node_data.title | |||||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value | |||||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||||
| workflow_node_execution.created_by = workflow_run.created_by | |||||
| workflow_node_execution.execution_metadata = json.dumps( | |||||
| { | |||||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||||
| } | |||||
| def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: | |||||
| # Create a domain model | |||||
| created_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| metadata = { | |||||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||||
| } | |||||
| domain_execution = NodeExecution( | |||||
| id=str(uuid4()), | |||||
| workflow_id=workflow_run.workflow_id, | |||||
| workflow_run_id=workflow_run.id, | |||||
| predecessor_node_id=event.predecessor_node_id, | |||||
| index=event.node_run_index, | |||||
| node_execution_id=event.node_execution_id, | |||||
| node_id=event.node_id, | |||||
| node_type=event.node_type, | |||||
| title=event.node_data.title, | |||||
| status=NodeExecutionStatus.RUNNING, | |||||
| metadata=metadata, | |||||
| created_at=created_at, | |||||
| ) | ) | ||||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| # Use the instance repository to save the workflow node execution | |||||
| self._workflow_node_execution_repository.save(workflow_node_execution) | |||||
| # Use the instance repository to save the domain model | |||||
| self._workflow_node_execution_repository.save(domain_execution) | |||||
| return domain_execution | |||||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||||
| return workflow_node_execution | |||||
| def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: | |||||
| # Get the domain model from repository | |||||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||||
| if not domain_execution: | |||||
| raise ValueError(f"Domain node execution not found: {event.node_execution_id}") | |||||
| def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: | |||||
| workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) | |||||
| # Process data | |||||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | inputs = WorkflowEntry.handle_special_values(event.inputs) | ||||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | process_data = WorkflowEntry.handle_special_values(event.process_data) | ||||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | outputs = WorkflowEntry.handle_special_values(event.outputs) | ||||
| execution_metadata_dict = dict(event.execution_metadata or {}) | |||||
| execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None | |||||
| # Convert metadata keys to strings | |||||
| execution_metadata_dict = {} | |||||
| if event.execution_metadata: | |||||
| for key, value in event.execution_metadata.items(): | |||||
| execution_metadata_dict[key] = value | |||||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | finished_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| elapsed_time = (finished_at - event.start_at).total_seconds() | elapsed_time = (finished_at - event.start_at).total_seconds() | ||||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||||
| # Update domain model | |||||
| domain_execution.status = NodeExecutionStatus.SUCCEEDED | |||||
| domain_execution.update_from_mapping( | |||||
| inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict | |||||
| ) | |||||
| domain_execution.finished_at = finished_at | |||||
| domain_execution.elapsed_time = elapsed_time | |||||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||||
| workflow_node_execution.execution_metadata = execution_metadata | |||||
| workflow_node_execution.finished_at = finished_at | |||||
| workflow_node_execution.elapsed_time = elapsed_time | |||||
| # Update the repository with the domain model | |||||
| self._workflow_node_execution_repository.save(domain_execution) | |||||
| # Use the instance repository to update the workflow node execution | |||||
| self._workflow_node_execution_repository.update(workflow_node_execution) | |||||
| return workflow_node_execution | |||||
| return domain_execution | |||||
| def _handle_workflow_node_execution_failed( | def _handle_workflow_node_execution_failed( | ||||
| self, | self, | ||||
| | QueueNodeInIterationFailedEvent | | QueueNodeInIterationFailedEvent | ||||
| | QueueNodeInLoopFailedEvent | | QueueNodeInLoopFailedEvent | ||||
| | QueueNodeExceptionEvent, | | QueueNodeExceptionEvent, | ||||
| ) -> WorkflowNodeExecution: | |||||
| ) -> NodeExecution: | |||||
| """ | """ | ||||
| Workflow node execution failed | Workflow node execution failed | ||||
| :param event: queue node failed event | :param event: queue node failed event | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) | |||||
| # Get the domain model from repository | |||||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||||
| if not domain_execution: | |||||
| raise ValueError(f"Domain node execution not found: {event.node_execution_id}") | |||||
| # Process data | |||||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | inputs = WorkflowEntry.handle_special_values(event.inputs) | ||||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | process_data = WorkflowEntry.handle_special_values(event.process_data) | ||||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | outputs = WorkflowEntry.handle_special_values(event.outputs) | ||||
| # Convert metadata keys to strings | |||||
| execution_metadata_dict = {} | |||||
| if event.execution_metadata: | |||||
| for key, value in event.execution_metadata.items(): | |||||
| execution_metadata_dict[key] = value | |||||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | finished_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| elapsed_time = (finished_at - event.start_at).total_seconds() | elapsed_time = (finished_at - event.start_at).total_seconds() | ||||
| execution_metadata = ( | |||||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||||
| ) | |||||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||||
| workflow_node_execution.status = ( | |||||
| WorkflowNodeExecutionStatus.FAILED.value | |||||
| # Update domain model | |||||
| domain_execution.status = ( | |||||
| NodeExecutionStatus.FAILED | |||||
| if not isinstance(event, QueueNodeExceptionEvent) | if not isinstance(event, QueueNodeExceptionEvent) | ||||
| else WorkflowNodeExecutionStatus.EXCEPTION.value | |||||
| else NodeExecutionStatus.EXCEPTION | |||||
| ) | ) | ||||
| workflow_node_execution.error = event.error | |||||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||||
| workflow_node_execution.finished_at = finished_at | |||||
| workflow_node_execution.elapsed_time = elapsed_time | |||||
| workflow_node_execution.execution_metadata = execution_metadata | |||||
| domain_execution.error = event.error | |||||
| domain_execution.update_from_mapping( | |||||
| inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict | |||||
| ) | |||||
| domain_execution.finished_at = finished_at | |||||
| domain_execution.elapsed_time = elapsed_time | |||||
| self._workflow_node_execution_repository.update(workflow_node_execution) | |||||
| # Update the repository with the domain model | |||||
| self._workflow_node_execution_repository.save(domain_execution) | |||||
| return workflow_node_execution | |||||
| return domain_execution | |||||
| def _handle_workflow_node_execution_retried( | def _handle_workflow_node_execution_retried( | ||||
| self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent | self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent | ||||
| ) -> WorkflowNodeExecution: | |||||
| ) -> NodeExecution: | |||||
| """ | """ | ||||
| Workflow node execution failed | Workflow node execution failed | ||||
| :param workflow_run: workflow run | :param workflow_run: workflow run | ||||
| elapsed_time = (finished_at - created_at).total_seconds() | elapsed_time = (finished_at - created_at).total_seconds() | ||||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | inputs = WorkflowEntry.handle_special_values(event.inputs) | ||||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | outputs = WorkflowEntry.handle_special_values(event.outputs) | ||||
| # Convert metadata keys to strings | |||||
| origin_metadata = { | origin_metadata = { | ||||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | ||||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | ||||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | ||||
| } | } | ||||
| merged_metadata = ( | |||||
| {**jsonable_encoder(event.execution_metadata), **origin_metadata} | |||||
| if event.execution_metadata is not None | |||||
| else origin_metadata | |||||
| # Convert execution metadata keys to strings | |||||
| execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} | |||||
| if event.execution_metadata: | |||||
| for key, value in event.execution_metadata.items(): | |||||
| execution_metadata_dict[key] = value | |||||
| merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata | |||||
| # Create a domain model | |||||
| domain_execution = NodeExecution( | |||||
| id=str(uuid4()), | |||||
| workflow_id=workflow_run.workflow_id, | |||||
| workflow_run_id=workflow_run.id, | |||||
| predecessor_node_id=event.predecessor_node_id, | |||||
| node_execution_id=event.node_execution_id, | |||||
| node_id=event.node_id, | |||||
| node_type=event.node_type, | |||||
| title=event.node_data.title, | |||||
| status=NodeExecutionStatus.RETRY, | |||||
| created_at=created_at, | |||||
| finished_at=finished_at, | |||||
| elapsed_time=elapsed_time, | |||||
| error=event.error, | |||||
| index=event.node_run_index, | |||||
| ) | ) | ||||
| execution_metadata = json.dumps(merged_metadata) | |||||
| workflow_node_execution = WorkflowNodeExecution() | |||||
| workflow_node_execution.id = str(uuid4()) | |||||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||||
| workflow_node_execution.app_id = workflow_run.app_id | |||||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||||
| workflow_node_execution.predecessor_node_id = event.predecessor_node_id | |||||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||||
| workflow_node_execution.node_id = event.node_id | |||||
| workflow_node_execution.node_type = event.node_type.value | |||||
| workflow_node_execution.title = event.node_data.title | |||||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value | |||||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||||
| workflow_node_execution.created_by = workflow_run.created_by | |||||
| workflow_node_execution.created_at = created_at | |||||
| workflow_node_execution.finished_at = finished_at | |||||
| workflow_node_execution.elapsed_time = elapsed_time | |||||
| workflow_node_execution.error = event.error | |||||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||||
| workflow_node_execution.execution_metadata = execution_metadata | |||||
| workflow_node_execution.index = event.node_run_index | |||||
| # Use the instance repository to save the workflow node execution | |||||
| self._workflow_node_execution_repository.save(workflow_node_execution) | |||||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||||
| return workflow_node_execution | |||||
| # Update with mappings | |||||
| domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) | |||||
| # Use the instance repository to save the domain model | |||||
| self._workflow_node_execution_repository.save(domain_execution) | |||||
| return domain_execution | |||||
| def _workflow_start_to_stream_response( | def _workflow_start_to_stream_response( | ||||
| self, | self, | ||||
| workflow_run: WorkflowRun, | workflow_run: WorkflowRun, | ||||
| ) -> WorkflowFinishStreamResponse: | ) -> WorkflowFinishStreamResponse: | ||||
| created_by = None | created_by = None | ||||
| if workflow_run.created_by_role == CreatedByRole.ACCOUNT: | |||||
| if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: | |||||
| stmt = select(Account).where(Account.id == workflow_run.created_by) | stmt = select(Account).where(Account.id == workflow_run.created_by) | ||||
| account = session.scalar(stmt) | account = session.scalar(stmt) | ||||
| if account: | if account: | ||||
| "name": account.name, | "name": account.name, | ||||
| "email": account.email, | "email": account.email, | ||||
| } | } | ||||
| elif workflow_run.created_by_role == CreatedByRole.END_USER: | |||||
| elif workflow_run.created_by_role == CreatorUserRole.END_USER: | |||||
| stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) | stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) | ||||
| end_user = session.scalar(stmt) | end_user = session.scalar(stmt) | ||||
| if end_user: | if end_user: | ||||
| *, | *, | ||||
| event: QueueNodeStartedEvent, | event: QueueNodeStartedEvent, | ||||
| task_id: str, | task_id: str, | ||||
| workflow_node_execution: WorkflowNodeExecution, | |||||
| workflow_node_execution: NodeExecution, | |||||
| ) -> Optional[NodeStartStreamResponse]: | ) -> Optional[NodeStartStreamResponse]: | ||||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||||
| return None | return None | ||||
| if not workflow_node_execution.workflow_run_id: | if not workflow_node_execution.workflow_run_id: | ||||
| return None | return None | ||||
| title=workflow_node_execution.title, | title=workflow_node_execution.title, | ||||
| index=workflow_node_execution.index, | index=workflow_node_execution.index, | ||||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | predecessor_node_id=workflow_node_execution.predecessor_node_id, | ||||
| inputs=workflow_node_execution.inputs_dict, | |||||
| inputs=workflow_node_execution.inputs, | |||||
| created_at=int(workflow_node_execution.created_at.timestamp()), | created_at=int(workflow_node_execution.created_at.timestamp()), | ||||
| parallel_id=event.parallel_id, | parallel_id=event.parallel_id, | ||||
| parallel_start_node_id=event.parallel_start_node_id, | parallel_start_node_id=event.parallel_start_node_id, | ||||
| | QueueNodeInLoopFailedEvent | | QueueNodeInLoopFailedEvent | ||||
| | QueueNodeExceptionEvent, | | QueueNodeExceptionEvent, | ||||
| task_id: str, | task_id: str, | ||||
| workflow_node_execution: WorkflowNodeExecution, | |||||
| workflow_node_execution: NodeExecution, | |||||
| ) -> Optional[NodeFinishStreamResponse]: | ) -> Optional[NodeFinishStreamResponse]: | ||||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||||
| return None | return None | ||||
| if not workflow_node_execution.workflow_run_id: | if not workflow_node_execution.workflow_run_id: | ||||
| return None | return None | ||||
| index=workflow_node_execution.index, | index=workflow_node_execution.index, | ||||
| title=workflow_node_execution.title, | title=workflow_node_execution.title, | ||||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | predecessor_node_id=workflow_node_execution.predecessor_node_id, | ||||
| inputs=workflow_node_execution.inputs_dict, | |||||
| process_data=workflow_node_execution.process_data_dict, | |||||
| outputs=workflow_node_execution.outputs_dict, | |||||
| inputs=workflow_node_execution.inputs, | |||||
| process_data=workflow_node_execution.process_data, | |||||
| outputs=workflow_node_execution.outputs, | |||||
| status=workflow_node_execution.status, | status=workflow_node_execution.status, | ||||
| error=workflow_node_execution.error, | error=workflow_node_execution.error, | ||||
| elapsed_time=workflow_node_execution.elapsed_time, | elapsed_time=workflow_node_execution.elapsed_time, | ||||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||||
| execution_metadata=workflow_node_execution.metadata, | |||||
| created_at=int(workflow_node_execution.created_at.timestamp()), | created_at=int(workflow_node_execution.created_at.timestamp()), | ||||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | finished_at=int(workflow_node_execution.finished_at.timestamp()), | ||||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||||
| parallel_id=event.parallel_id, | parallel_id=event.parallel_id, | ||||
| parallel_start_node_id=event.parallel_start_node_id, | parallel_start_node_id=event.parallel_start_node_id, | ||||
| parent_parallel_id=event.parent_parallel_id, | parent_parallel_id=event.parent_parallel_id, | ||||
| *, | *, | ||||
| event: QueueNodeRetryEvent, | event: QueueNodeRetryEvent, | ||||
| task_id: str, | task_id: str, | ||||
| workflow_node_execution: WorkflowNodeExecution, | |||||
| workflow_node_execution: NodeExecution, | |||||
| ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: | ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: | ||||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||||
| return None | return None | ||||
| if not workflow_node_execution.workflow_run_id: | if not workflow_node_execution.workflow_run_id: | ||||
| return None | return None | ||||
| index=workflow_node_execution.index, | index=workflow_node_execution.index, | ||||
| title=workflow_node_execution.title, | title=workflow_node_execution.title, | ||||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | predecessor_node_id=workflow_node_execution.predecessor_node_id, | ||||
| inputs=workflow_node_execution.inputs_dict, | |||||
| process_data=workflow_node_execution.process_data_dict, | |||||
| outputs=workflow_node_execution.outputs_dict, | |||||
| inputs=workflow_node_execution.inputs, | |||||
| process_data=workflow_node_execution.process_data, | |||||
| outputs=workflow_node_execution.outputs, | |||||
| status=workflow_node_execution.status, | status=workflow_node_execution.status, | ||||
| error=workflow_node_execution.error, | error=workflow_node_execution.error, | ||||
| elapsed_time=workflow_node_execution.elapsed_time, | elapsed_time=workflow_node_execution.elapsed_time, | ||||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||||
| execution_metadata=workflow_node_execution.metadata, | |||||
| created_at=int(workflow_node_execution.created_at.timestamp()), | created_at=int(workflow_node_execution.created_at.timestamp()), | ||||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | finished_at=int(workflow_node_execution.finished_at.timestamp()), | ||||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||||
| parallel_id=event.parallel_id, | parallel_id=event.parallel_id, | ||||
| parallel_start_node_id=event.parallel_start_node_id, | parallel_start_node_id=event.parallel_start_node_id, | ||||
| parent_parallel_id=event.parent_parallel_id, | parent_parallel_id=event.parent_parallel_id, | ||||
| return workflow_run | return workflow_run | ||||
| def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: | |||||
| # First check the cache for performance | |||||
| if node_execution_id in self._workflow_node_executions: | |||||
| cached_execution = self._workflow_node_executions[node_execution_id] | |||||
| # No need to merge with session since expire_on_commit=False | |||||
| return cached_execution | |||||
| # If not in cache, use the instance repository to get by node_execution_id | |||||
| execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) | |||||
| if not execution: | |||||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | |||||
| # Update cache | |||||
| self._workflow_node_executions[node_execution_id] = execution | |||||
| return execution | |||||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | ||||
| """ | """ | ||||
| Handle agent log | Handle agent log |
| Whitelist, | Whitelist, | ||||
| ) | ) | ||||
| from .engine import db | from .engine import db | ||||
| from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom | |||||
| from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom | |||||
| from .model import ( | from .model import ( | ||||
| ApiRequest, | ApiRequest, | ||||
| ApiToken, | ApiToken, | ||||
| "CeleryTaskSet", | "CeleryTaskSet", | ||||
| "Conversation", | "Conversation", | ||||
| "ConversationVariable", | "ConversationVariable", | ||||
| "CreatedByRole", | |||||
| "CreatorUserRole", | |||||
| "DataSourceApiKeyAuthBinding", | "DataSourceApiKeyAuthBinding", | ||||
| "DataSourceOauthBinding", | "DataSourceOauthBinding", | ||||
| "Dataset", | "Dataset", |
| from enum import StrEnum | from enum import StrEnum | ||||
| class CreatedByRole(StrEnum): | |||||
| class CreatorUserRole(StrEnum): | |||||
| ACCOUNT = "account" | ACCOUNT = "account" | ||||
| END_USER = "end_user" | END_USER = "end_user" | ||||
| from .account import Account, Tenant | from .account import Account, Tenant | ||||
| from .base import Base | from .base import Base | ||||
| from .engine import db | from .engine import db | ||||
| from .enums import CreatedByRole | |||||
| from .enums import CreatorUserRole | |||||
| from .types import StringUUID | from .types import StringUUID | ||||
| from .workflow import WorkflowRunStatus | from .workflow import WorkflowRunStatus | ||||
| url: str | None = None, | url: str | None = None, | ||||
| belongs_to: Literal["user", "assistant"] | None = None, | belongs_to: Literal["user", "assistant"] | None = None, | ||||
| upload_file_id: str | None = None, | upload_file_id: str | None = None, | ||||
| created_by_role: CreatedByRole, | |||||
| created_by_role: CreatorUserRole, | |||||
| created_by: str, | created_by: str, | ||||
| ): | ): | ||||
| self.message_id = message_id | self.message_id = message_id | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | ||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||||
| app_id = db.Column(StringUUID, nullable=True) | app_id = db.Column(StringUUID, nullable=True) | ||||
| type = db.Column(db.String(255), nullable=False) | type = db.Column(db.String(255), nullable=False) | ||||
| external_user_id = db.Column(db.String(255), nullable=True) | external_user_id = db.Column(db.String(255), nullable=True) | ||||
| size: int, | size: int, | ||||
| extension: str, | extension: str, | ||||
| mime_type: str, | mime_type: str, | ||||
| created_by_role: CreatedByRole, | |||||
| created_by_role: CreatorUserRole, | |||||
| created_by: str, | created_by: str, | ||||
| created_at: datetime, | created_at: datetime, | ||||
| used: bool, | used: bool, |
| from .account import Account | from .account import Account | ||||
| from .base import Base | from .base import Base | ||||
| from .engine import db | from .engine import db | ||||
| from .enums import CreatedByRole | |||||
| from .enums import CreatorUserRole | |||||
| from .types import StringUUID | from .types import StringUUID | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| @property | @property | ||||
| def created_by_account(self): | def created_by_account(self): | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||||
| @property | @property | ||||
| def created_by_end_user(self): | def created_by_end_user(self): | ||||
| from models.model import EndUser | from models.model import EndUser | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||||
| @property | @property | ||||
| def graph_dict(self): | def graph_dict(self): | ||||
| @property | @property | ||||
| def created_by_account(self): | def created_by_account(self): | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| # TODO(-LAN-): Avoid using db.session.get() here. | # TODO(-LAN-): Avoid using db.session.get() here. | ||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||||
| @property | @property | ||||
| def created_by_end_user(self): | def created_by_end_user(self): | ||||
| from models.model import EndUser | from models.model import EndUser | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| # TODO(-LAN-): Avoid using db.session.get() here. | # TODO(-LAN-): Avoid using db.session.get() here. | ||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||||
| @property | @property | ||||
| def inputs_dict(self): | def inputs_dict(self): | ||||
| @property | @property | ||||
| def created_by_account(self): | def created_by_account(self): | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||||
| @property | @property | ||||
| def created_by_end_user(self): | def created_by_end_user(self): | ||||
| from models.model import EndUser | from models.model import EndUser | ||||
| created_by_role = CreatedByRole(self.created_by_role) | |||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||||
| created_by_role = CreatorUserRole(self.created_by_role) | |||||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||||
| class ConversationVariable(Base): | class ConversationVariable(Base): |
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import EndUser, UploadFile | from models.model import EndUser, UploadFile | ||||
| from .errors.file import FileTooLargeError, UnsupportedFileTypeError | from .errors.file import FileTooLargeError, UnsupportedFileTypeError | ||||
| size=file_size, | size=file_size, | ||||
| extension=extension, | extension=extension, | ||||
| mime_type=mimetype, | mime_type=mimetype, | ||||
| created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), | |||||
| created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), | |||||
| created_by=user.id, | created_by=user.id, | ||||
| created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | ||||
| used=False, | used=False, | ||||
| extension="txt", | extension="txt", | ||||
| mime_type="text/plain", | mime_type="text/plain", | ||||
| created_by=current_user.id, | created_by=current_user.id, | ||||
| created_by_role=CreatedByRole.ACCOUNT, | |||||
| created_by_role=CreatorUserRole.ACCOUNT, | |||||
| created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | ||||
| used=True, | used=True, | ||||
| used_by=current_user.id, | used_by=current_user.id, |
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from models import App, EndUser, WorkflowAppLog, WorkflowRun | from models import App, EndUser, WorkflowAppLog, WorkflowRun | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.workflow import WorkflowRunStatus | from models.workflow import WorkflowRunStatus | ||||
| stmt = stmt.outerjoin( | stmt = stmt.outerjoin( | ||||
| EndUser, | EndUser, | ||||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), | |||||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), | |||||
| ).where(or_(*keyword_conditions)) | ).where(or_(*keyword_conditions)) | ||||
| if status: | if status: |
| import threading | import threading | ||||
| from collections.abc import Sequence | |||||
| from typing import Optional | from typing import Optional | ||||
| import contexts | import contexts | ||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | from core.workflow.repository.workflow_node_execution_repository import OrderConfig | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | from libs.infinite_scroll_pagination import InfiniteScrollPagination | ||||
| from models.enums import WorkflowRunTriggeredFrom | |||||
| from models.model import App | |||||
| from models.workflow import ( | |||||
| from models import ( | |||||
| Account, | |||||
| App, | |||||
| EndUser, | |||||
| WorkflowNodeExecution, | WorkflowNodeExecution, | ||||
| WorkflowRun, | WorkflowRun, | ||||
| WorkflowRunTriggeredFrom, | |||||
| ) | ) | ||||
| return workflow_run | return workflow_run | ||||
| def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: | |||||
| def get_workflow_run_node_executions( | |||||
| self, | |||||
| app_model: App, | |||||
| run_id: str, | |||||
| user: Account | EndUser, | |||||
| ) -> Sequence[WorkflowNodeExecution]: | |||||
| """ | """ | ||||
| Get workflow run node execution list | Get workflow run node execution list | ||||
| """ | """ | ||||
| if not workflow_run: | if not workflow_run: | ||||
| return [] | return [] | ||||
| # Use the repository to get the node executions | |||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||||
| session_factory=db.engine, | |||||
| user=user, | |||||
| app_id=app_model.id, | |||||
| triggered_from=None, | |||||
| ) | ) | ||||
| # Use the repository to get the node executions with ordering | # Use the repository to get the node executions with ordering | ||||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | order_config = OrderConfig(order_by=["index"], order_direction="desc") | ||||
| node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) | |||||
| node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config) | |||||
| return list(node_executions) | |||||
| return node_executions |
| from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.model import App, AppMode | from models.model import App, AppMode | ||||
| from models.tools import WorkflowToolProvider | from models.tools import WorkflowToolProvider | ||||
| from models.workflow import ( | from models.workflow import ( | ||||
| workflow_node_execution.created_by = account.id | workflow_node_execution.created_by = account.id | ||||
| workflow_node_execution.workflow_id = draft_workflow.id | workflow_node_execution.workflow_id = draft_workflow.id | ||||
| # Use the repository to save the workflow node execution | |||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||||
| session_factory=db.engine, | |||||
| user=account, | |||||
| app_id=app_model.id, | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||||
| ) | ) | ||||
| repository.save(workflow_node_execution) | repository.save(workflow_node_execution) | ||||
| workflow_node_execution.node_type = node_instance.node_type | workflow_node_execution.node_type = node_instance.node_type | ||||
| workflow_node_execution.title = node_instance.node_data.title | workflow_node_execution.title = node_instance.node_data.title | ||||
| workflow_node_execution.elapsed_time = time.perf_counter() - start_at | workflow_node_execution.elapsed_time = time.perf_counter() - start_at | ||||
| workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value | |||||
| workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value | |||||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| if run_succeeded and node_run_result: | if run_succeeded and node_run_result: |
| import click | import click | ||||
| from celery import shared_task # type: ignore | from celery import shared_task # type: ignore | ||||
| from sqlalchemy import delete | |||||
| from sqlalchemy import delete, select | |||||
| from sqlalchemy.exc import SQLAlchemyError | from sqlalchemy.exc import SQLAlchemyError | ||||
| from sqlalchemy.orm import Session | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import AppDatasetJoin | |||||
| from models.model import ( | |||||
| from models import ( | |||||
| Account, | |||||
| ApiToken, | ApiToken, | ||||
| App, | |||||
| AppAnnotationHitHistory, | AppAnnotationHitHistory, | ||||
| AppAnnotationSetting, | AppAnnotationSetting, | ||||
| AppDatasetJoin, | |||||
| AppModelConfig, | AppModelConfig, | ||||
| Conversation, | Conversation, | ||||
| EndUser, | EndUser, | ||||
| def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | ||||
| # Get app's owner | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id) | |||||
| user = session.scalar(stmt) | |||||
| if user is None: | |||||
| errmsg = ( | |||||
| f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found" | |||||
| ) | |||||
| logging.error(errmsg) | |||||
| raise ValueError(errmsg) | |||||
| # Create a repository instance for WorkflowNodeExecution | # Create a repository instance for WorkflowNodeExecution | ||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | repository = SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=db.engine, tenant_id=tenant_id, app_id=app_id | |||||
| session_factory=db.engine, | |||||
| user=user, | |||||
| app_id=app_id, | |||||
| triggered_from=None, | |||||
| ) | ) | ||||
| # Use the clear method to delete all records for this tenant_id and app_id | # Use the clear method to delete all records for this tenant_id and app_id |
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | from core.workflow.workflow_cycle_manager import WorkflowCycleManager | ||||
| from models.enums import CreatedByRole | |||||
| from models.enums import CreatorUserRole | |||||
| from models.workflow import ( | from models.workflow import ( | ||||
| Workflow, | Workflow, | ||||
| WorkflowNodeExecution, | |||||
| WorkflowNodeExecutionStatus, | WorkflowNodeExecutionStatus, | ||||
| WorkflowRun, | WorkflowRun, | ||||
| WorkflowRunStatus, | WorkflowRunStatus, | ||||
| workflow_run.app_id = "test-app-id" | workflow_run.app_id = "test-app-id" | ||||
| workflow_run.workflow_id = "test-workflow-id" | workflow_run.workflow_id = "test-workflow-id" | ||||
| workflow_run.status = WorkflowRunStatus.RUNNING | workflow_run.status = WorkflowRunStatus.RUNNING | ||||
| workflow_run.created_by_role = CreatedByRole.ACCOUNT | |||||
| workflow_run.created_by_role = CreatorUserRole.ACCOUNT | |||||
| workflow_run.created_by = "test-user-id" | workflow_run.created_by = "test-user-id" | ||||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| workflow_run.inputs_dict = {"query": "test query"} | workflow_run.inputs_dict = {"query": "test query"} | ||||
| ): | ): | ||||
| """Test initialization of WorkflowCycleManager""" | """Test initialization of WorkflowCycleManager""" | ||||
| assert workflow_cycle_manager._workflow_run is None | assert workflow_cycle_manager._workflow_run is None | ||||
| assert workflow_cycle_manager._workflow_node_executions == {} | |||||
| assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity | assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity | ||||
| assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables | assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables | ||||
| assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository | assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository | ||||
| session=mock_session, | session=mock_session, | ||||
| workflow_id="test-workflow-id", | workflow_id="test-workflow-id", | ||||
| user_id="test-user-id", | user_id="test-user-id", | ||||
| created_by_role=CreatedByRole.ACCOUNT, | |||||
| created_by_role=CreatorUserRole.ACCOUNT, | |||||
| ) | ) | ||||
| # Verify the result | # Verify the result | ||||
| assert workflow_run.workflow_id == mock_workflow.id | assert workflow_run.workflow_id == mock_workflow.id | ||||
| assert workflow_run.sequence_number == 6 # max_sequence + 1 | assert workflow_run.sequence_number == 6 # max_sequence + 1 | ||||
| assert workflow_run.status == WorkflowRunStatus.RUNNING | assert workflow_run.status == WorkflowRunStatus.RUNNING | ||||
| assert workflow_run.created_by_role == CreatedByRole.ACCOUNT | |||||
| assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT | |||||
| assert workflow_run.created_by == "test-user-id" | assert workflow_run.created_by == "test-user-id" | ||||
| # Verify session.add was called | # Verify session.add was called | ||||
| ) | ) | ||||
| # Verify the result | # Verify the result | ||||
| assert result.tenant_id == mock_workflow_run.tenant_id | |||||
| assert result.app_id == mock_workflow_run.app_id | |||||
| # NodeExecution doesn't have tenant_id attribute, it's handled at repository level | |||||
| # assert result.tenant_id == mock_workflow_run.tenant_id | |||||
| # assert result.app_id == mock_workflow_run.app_id | |||||
| assert result.workflow_id == mock_workflow_run.workflow_id | assert result.workflow_id == mock_workflow_run.workflow_id | ||||
| assert result.workflow_run_id == mock_workflow_run.id | assert result.workflow_run_id == mock_workflow_run.id | ||||
| assert result.node_execution_id == event.node_execution_id | assert result.node_execution_id == event.node_execution_id | ||||
| assert result.node_id == event.node_id | assert result.node_id == event.node_id | ||||
| assert result.node_type == event.node_type.value | |||||
| assert result.node_type == event.node_type | |||||
| assert result.title == event.node_data.title | assert result.title == event.node_data.title | ||||
| assert result.status == WorkflowNodeExecutionStatus.RUNNING.value | assert result.status == WorkflowNodeExecutionStatus.RUNNING.value | ||||
| assert result.created_by_role == mock_workflow_run.created_by_role | |||||
| assert result.created_by == mock_workflow_run.created_by | |||||
| # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level | |||||
| # assert result.created_by_role == mock_workflow_run.created_by_role | |||||
| # assert result.created_by == mock_workflow_run.created_by | |||||
| # Verify save was called | # Verify save was called | ||||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) | workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) | ||||
| # Verify the node execution was added to the cache | |||||
| assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result | |||||
| def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): | def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): | ||||
| """Test _get_workflow_run method""" | """Test _get_workflow_run method""" | ||||
| event.execution_metadata = {"metadata": "test metadata"} | event.execution_metadata = {"metadata": "test metadata"} | ||||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | event.start_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| # Create a mock workflow node execution | |||||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||||
| # Create a mock node execution | |||||
| node_execution = MagicMock() | |||||
| node_execution.node_execution_id = "test-node-execution-id" | node_execution.node_execution_id = "test-node-execution-id" | ||||
| # Mock _get_workflow_node_execution to return the mock node execution | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||||
| event=event, | |||||
| ) | |||||
| # Mock the repository to return the node execution | |||||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||||
| assert result.inputs == json.dumps(event.inputs) | |||||
| assert result.process_data == json.dumps(event.process_data) | |||||
| assert result.outputs == json.dumps(event.outputs) | |||||
| assert result.finished_at is not None | |||||
| assert result.elapsed_time is not None | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||||
| event=event, | |||||
| ) | |||||
| # Verify update was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||||
| # Verify save was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) | |||||
| def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): | def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): | ||||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | event.start_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| event.error = "Test error message" | event.error = "Test error message" | ||||
| # Create a mock workflow node execution | |||||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||||
| # Create a mock node execution | |||||
| node_execution = MagicMock() | |||||
| node_execution.node_execution_id = "test-node-execution-id" | node_execution.node_execution_id = "test-node-execution-id" | ||||
| # Mock _get_workflow_node_execution to return the mock node execution | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||||
| event=event, | |||||
| ) | |||||
| # Mock the repository to return the node execution | |||||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||||
| assert result.error == "Test error message" | |||||
| assert result.inputs == json.dumps(event.inputs) | |||||
| assert result.process_data == json.dumps(event.process_data) | |||||
| assert result.outputs == json.dumps(event.outputs) | |||||
| assert result.finished_at is not None | |||||
| assert result.elapsed_time is not None | |||||
| assert result.execution_metadata == json.dumps(event.execution_metadata) | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||||
| event=event, | |||||
| ) | |||||
| # Verify update was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||||
| assert result.error == "Test error message" | |||||
| # Verify save was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) |
| Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. | Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. | ||||
| """ | """ | ||||
| from unittest.mock import MagicMock | |||||
| import json | |||||
| from datetime import datetime | |||||
| from unittest.mock import MagicMock, PropertyMock | |||||
| import pytest | import pytest | ||||
| from pytest_mock import MockerFixture | from pytest_mock import MockerFixture | ||||
| from sqlalchemy.orm import Session, sessionmaker | from sqlalchemy.orm import Session, sessionmaker | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus | |||||
| from core.workflow.nodes.enums import NodeType | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | from core.workflow.repository.workflow_node_execution_repository import OrderConfig | ||||
| from models.workflow import WorkflowNodeExecution | |||||
| from models.account import Account, Tenant | |||||
| from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom | |||||
| def configure_mock_execution(mock_execution): | |||||
| """Configure a mock execution with proper JSON serializable values.""" | |||||
| # Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values | |||||
| type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}') | |||||
| type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}') | |||||
| type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}') | |||||
| type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}') | |||||
| # Configure status and triggered_from to be valid enum values | |||||
| mock_execution.status = "running" | |||||
| mock_execution.triggered_from = "workflow-run" | |||||
| return mock_execution | |||||
| @pytest.fixture | @pytest.fixture | ||||
| @pytest.fixture | @pytest.fixture | ||||
| def repository(session): | |||||
| def mock_user(): | |||||
| """Create a user instance for testing.""" | |||||
| user = Account() | |||||
| user.id = "test-user-id" | |||||
| tenant = Tenant() | |||||
| tenant.id = "test-tenant" | |||||
| tenant.name = "Test Workspace" | |||||
| user._current_tenant = MagicMock() | |||||
| user._current_tenant.id = "test-tenant" | |||||
| return user | |||||
| @pytest.fixture | |||||
| def repository(session, mock_user): | |||||
| """Create a repository instance with test data.""" | """Create a repository instance with test data.""" | ||||
| _, session_factory = session | _, session_factory = session | ||||
| tenant_id = "test-tenant" | |||||
| app_id = "test-app" | app_id = "test-app" | ||||
| return SQLAlchemyWorkflowNodeExecutionRepository( | return SQLAlchemyWorkflowNodeExecutionRepository( | ||||
| session_factory=session_factory, tenant_id=tenant_id, app_id=app_id | |||||
| session_factory=session_factory, | |||||
| user=mock_user, | |||||
| app_id=app_id, | |||||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||||
| ) | ) | ||||
| execution = MagicMock(spec=WorkflowNodeExecution) | execution = MagicMock(spec=WorkflowNodeExecution) | ||||
| execution.tenant_id = None | execution.tenant_id = None | ||||
| execution.app_id = None | execution.app_id = None | ||||
| execution.inputs = None | |||||
| execution.process_data = None | |||||
| execution.outputs = None | |||||
| execution.metadata = None | |||||
| # Mock the _to_db_model method to return the execution itself | |||||
| # This simulates the behavior of setting tenant_id and app_id | |||||
| repository._to_db_model = MagicMock(return_value=execution) | |||||
| # Call save method | # Call save method | ||||
| repository.save(execution) | repository.save(execution) | ||||
| # Assert tenant_id and app_id are set | |||||
| assert execution.tenant_id == repository._tenant_id | |||||
| assert execution.app_id == repository._app_id | |||||
| # Assert _to_db_model was called with the execution | |||||
| repository._to_db_model.assert_called_once_with(execution) | |||||
| # Assert session.add was called | |||||
| session_obj.add.assert_called_once_with(execution) | |||||
| # Assert session.merge was called (now using merge for both save and update) | |||||
| session_obj.merge.assert_called_once_with(execution) | |||||
| def test_save_with_existing_tenant_id(repository, session): | def test_save_with_existing_tenant_id(repository, session): | ||||
| execution = MagicMock(spec=WorkflowNodeExecution) | execution = MagicMock(spec=WorkflowNodeExecution) | ||||
| execution.tenant_id = "existing-tenant" | execution.tenant_id = "existing-tenant" | ||||
| execution.app_id = None | execution.app_id = None | ||||
| execution.inputs = None | |||||
| execution.process_data = None | |||||
| execution.outputs = None | |||||
| execution.metadata = None | |||||
| # Create a modified execution that will be returned by _to_db_model | |||||
| modified_execution = MagicMock(spec=WorkflowNodeExecution) | |||||
| modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change | |||||
| modified_execution.app_id = repository._app_id # App ID should be set | |||||
| # Mock the _to_db_model method to return the modified execution | |||||
| repository._to_db_model = MagicMock(return_value=modified_execution) | |||||
| # Call save method | # Call save method | ||||
| repository.save(execution) | repository.save(execution) | ||||
| # Assert tenant_id is not changed and app_id is set | |||||
| assert execution.tenant_id == "existing-tenant" | |||||
| assert execution.app_id == repository._app_id | |||||
| # Assert _to_db_model was called with the execution | |||||
| repository._to_db_model.assert_called_once_with(execution) | |||||
| # Assert session.add was called | |||||
| session_obj.add.assert_called_once_with(execution) | |||||
| # Assert session.merge was called with the modified execution (now using merge for both save and update) | |||||
| session_obj.merge.assert_called_once_with(modified_execution) | |||||
| def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): | def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): | ||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution) | |||||
| # Create a properly configured mock execution | |||||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||||
| configure_mock_execution(mock_execution) | |||||
| session_obj.scalar.return_value = mock_execution | |||||
| # Create a mock domain model to be returned by _to_domain_model | |||||
| mock_domain_model = mocker.MagicMock() | |||||
| # Mock the _to_domain_model method to return our mock domain model | |||||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||||
| # Call method | # Call method | ||||
| result = repository.get_by_node_execution_id("test-node-execution-id") | result = repository.get_by_node_execution_id("test-node-execution-id") | ||||
| # Assert select was called with correct parameters | # Assert select was called with correct parameters | ||||
| mock_select.assert_called_once() | mock_select.assert_called_once() | ||||
| session_obj.scalar.assert_called_once_with(mock_stmt) | session_obj.scalar.assert_called_once_with(mock_stmt) | ||||
| assert result is not None | |||||
| # Assert _to_domain_model was called with the mock execution | |||||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||||
| # Assert the result is our mock domain model | |||||
| assert result is mock_domain_model | |||||
| def test_get_by_workflow_run(repository, session, mocker: MockerFixture): | def test_get_by_workflow_run(repository, session, mocker: MockerFixture): | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| mock_stmt.order_by.return_value = mock_stmt | mock_stmt.order_by.return_value = mock_stmt | ||||
| session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] | |||||
| # Create a properly configured mock execution | |||||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||||
| configure_mock_execution(mock_execution) | |||||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||||
| # Create a mock domain model to be returned by _to_domain_model | |||||
| mock_domain_model = mocker.MagicMock() | |||||
| # Mock the _to_domain_model method to return our mock domain model | |||||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||||
| # Call method | # Call method | ||||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | order_config = OrderConfig(order_by=["index"], order_direction="desc") | ||||
| # Assert select was called with correct parameters | # Assert select was called with correct parameters | ||||
| mock_select.assert_called_once() | mock_select.assert_called_once() | ||||
| session_obj.scalars.assert_called_once_with(mock_stmt) | session_obj.scalars.assert_called_once_with(mock_stmt) | ||||
| # Assert _to_domain_model was called with the mock execution | |||||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||||
| # Assert the result contains our mock domain model | |||||
| assert len(result) == 1 | |||||
| assert result[0] is mock_domain_model | |||||
| def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture): | |||||
| """Test get_db_models_by_workflow_run method.""" | |||||
| session_obj, _ = session | |||||
| # Set up mock | |||||
| mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") | |||||
| mock_stmt = mocker.MagicMock() | |||||
| mock_select.return_value = mock_stmt | |||||
| mock_stmt.where.return_value = mock_stmt | |||||
| mock_stmt.order_by.return_value = mock_stmt | |||||
| # Create a properly configured mock execution | |||||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||||
| configure_mock_execution(mock_execution) | |||||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||||
| # Mock the _to_domain_model method | |||||
| to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model") | |||||
| # Call method | |||||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | |||||
| result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) | |||||
| # Assert select was called with correct parameters | |||||
| mock_select.assert_called_once() | |||||
| session_obj.scalars.assert_called_once_with(mock_stmt) | |||||
| # Assert the result contains our mock db model directly (without conversion to domain model) | |||||
| assert len(result) == 1 | assert len(result) == 1 | ||||
| assert result[0] is mock_execution | |||||
| # Verify that _to_domain_model was NOT called (since we're returning raw DB models) | |||||
| to_domain_model_mock.assert_not_called() | |||||
| def test_get_running_executions(repository, session, mocker: MockerFixture): | def test_get_running_executions(repository, session, mocker: MockerFixture): | ||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] | |||||
| # Create a properly configured mock execution | |||||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||||
| configure_mock_execution(mock_execution) | |||||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||||
| # Create a mock domain model to be returned by _to_domain_model | |||||
| mock_domain_model = mocker.MagicMock() | |||||
| # Mock the _to_domain_model method to return our mock domain model | |||||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||||
| # Call method | # Call method | ||||
| result = repository.get_running_executions("test-workflow-run-id") | result = repository.get_running_executions("test-workflow-run-id") | ||||
| # Assert select was called with correct parameters | # Assert select was called with correct parameters | ||||
| mock_select.assert_called_once() | mock_select.assert_called_once() | ||||
| session_obj.scalars.assert_called_once_with(mock_stmt) | session_obj.scalars.assert_called_once_with(mock_stmt) | ||||
| # Assert _to_domain_model was called with the mock execution | |||||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||||
| # Assert the result contains our mock domain model | |||||
| assert len(result) == 1 | assert len(result) == 1 | ||||
| assert result[0] is mock_domain_model | |||||
| def test_update(repository, session): | |||||
| """Test update method.""" | |||||
| def test_update_via_save(repository, session): | |||||
| """Test updating an existing record via save method.""" | |||||
| session_obj, _ = session | session_obj, _ = session | ||||
| # Create a mock execution | # Create a mock execution | ||||
| execution = MagicMock(spec=WorkflowNodeExecution) | execution = MagicMock(spec=WorkflowNodeExecution) | ||||
| execution.tenant_id = None | execution.tenant_id = None | ||||
| execution.app_id = None | execution.app_id = None | ||||
| execution.inputs = None | |||||
| execution.process_data = None | |||||
| execution.outputs = None | |||||
| execution.metadata = None | |||||
| # Call update method | |||||
| repository.update(execution) | |||||
| # Mock the _to_db_model method to return the execution itself | |||||
| # This simulates the behavior of setting tenant_id and app_id | |||||
| repository._to_db_model = MagicMock(return_value=execution) | |||||
| # Assert tenant_id and app_id are set | |||||
| assert execution.tenant_id == repository._tenant_id | |||||
| assert execution.app_id == repository._app_id | |||||
| # Call save method to update an existing record | |||||
| repository.save(execution) | |||||
| # Assert session.merge was called | |||||
| # Assert _to_db_model was called with the execution | |||||
| repository._to_db_model.assert_called_once_with(execution) | |||||
| # Assert session.merge was called (for updates) | |||||
| session_obj.merge.assert_called_once_with(execution) | session_obj.merge.assert_called_once_with(execution) | ||||
| mock_stmt.where.assert_called() | mock_stmt.where.assert_called() | ||||
| session_obj.execute.assert_called_once_with(mock_stmt) | session_obj.execute.assert_called_once_with(mock_stmt) | ||||
| session_obj.commit.assert_called_once() | session_obj.commit.assert_called_once() | ||||
| def test_to_db_model(repository): | |||||
| """Test _to_db_model method.""" | |||||
| # Create a domain model | |||||
| domain_model = NodeExecution( | |||||
| id="test-id", | |||||
| workflow_id="test-workflow-id", | |||||
| node_execution_id="test-node-execution-id", | |||||
| workflow_run_id="test-workflow-run-id", | |||||
| index=1, | |||||
| predecessor_node_id="test-predecessor-id", | |||||
| node_id="test-node-id", | |||||
| node_type=NodeType.START, | |||||
| title="Test Node", | |||||
| inputs={"input_key": "input_value"}, | |||||
| process_data={"process_key": "process_value"}, | |||||
| outputs={"output_key": "output_value"}, | |||||
| status=NodeExecutionStatus.RUNNING, | |||||
| error=None, | |||||
| elapsed_time=1.5, | |||||
| metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100}, | |||||
| created_at=datetime.now(), | |||||
| finished_at=None, | |||||
| ) | |||||
| # Convert to DB model | |||||
| db_model = repository._to_db_model(domain_model) | |||||
| # Assert DB model has correct values | |||||
| assert isinstance(db_model, WorkflowNodeExecution) | |||||
| assert db_model.id == domain_model.id | |||||
| assert db_model.tenant_id == repository._tenant_id | |||||
| assert db_model.app_id == repository._app_id | |||||
| assert db_model.workflow_id == domain_model.workflow_id | |||||
| assert db_model.triggered_from == repository._triggered_from | |||||
| assert db_model.workflow_run_id == domain_model.workflow_run_id | |||||
| assert db_model.index == domain_model.index | |||||
| assert db_model.predecessor_node_id == domain_model.predecessor_node_id | |||||
| assert db_model.node_execution_id == domain_model.node_execution_id | |||||
| assert db_model.node_id == domain_model.node_id | |||||
| assert db_model.node_type == domain_model.node_type | |||||
| assert db_model.title == domain_model.title | |||||
| assert db_model.inputs_dict == domain_model.inputs | |||||
| assert db_model.process_data_dict == domain_model.process_data | |||||
| assert db_model.outputs_dict == domain_model.outputs | |||||
| assert db_model.execution_metadata_dict == domain_model.metadata | |||||
| assert db_model.status == domain_model.status | |||||
| assert db_model.error == domain_model.error | |||||
| assert db_model.elapsed_time == domain_model.elapsed_time | |||||
| assert db_model.created_at == domain_model.created_at | |||||
| assert db_model.created_by_role == repository._creator_user_role | |||||
| assert db_model.created_by == repository._creator_user_id | |||||
| assert db_model.finished_at == domain_model.finished_at | |||||
| def test_to_domain_model(repository): | |||||
| """Test _to_domain_model method.""" | |||||
| # Create input dictionaries | |||||
| inputs_dict = {"input_key": "input_value"} | |||||
| process_data_dict = {"process_key": "process_value"} | |||||
| outputs_dict = {"output_key": "output_value"} | |||||
| metadata_dict = {str(NodeRunMetadataKey.TOTAL_TOKENS): 100} | |||||
| # Create a DB model using our custom subclass | |||||
| db_model = WorkflowNodeExecution() | |||||
| db_model.id = "test-id" | |||||
| db_model.tenant_id = "test-tenant-id" | |||||
| db_model.app_id = "test-app-id" | |||||
| db_model.workflow_id = "test-workflow-id" | |||||
| db_model.triggered_from = "workflow-run" | |||||
| db_model.workflow_run_id = "test-workflow-run-id" | |||||
| db_model.index = 1 | |||||
| db_model.predecessor_node_id = "test-predecessor-id" | |||||
| db_model.node_execution_id = "test-node-execution-id" | |||||
| db_model.node_id = "test-node-id" | |||||
| db_model.node_type = NodeType.START.value | |||||
| db_model.title = "Test Node" | |||||
| db_model.inputs = json.dumps(inputs_dict) | |||||
| db_model.process_data = json.dumps(process_data_dict) | |||||
| db_model.outputs = json.dumps(outputs_dict) | |||||
| db_model.status = WorkflowNodeExecutionStatus.RUNNING | |||||
| db_model.error = None | |||||
| db_model.elapsed_time = 1.5 | |||||
| db_model.execution_metadata = json.dumps(metadata_dict) | |||||
| db_model.created_at = datetime.now() | |||||
| db_model.created_by_role = "account" | |||||
| db_model.created_by = "test-user-id" | |||||
| db_model.finished_at = None | |||||
| # Convert to domain model | |||||
| domain_model = repository._to_domain_model(db_model) | |||||
| # Assert domain model has correct values | |||||
| assert isinstance(domain_model, NodeExecution) | |||||
| assert domain_model.id == db_model.id | |||||
| assert domain_model.workflow_id == db_model.workflow_id | |||||
| assert domain_model.workflow_run_id == db_model.workflow_run_id | |||||
| assert domain_model.index == db_model.index | |||||
| assert domain_model.predecessor_node_id == db_model.predecessor_node_id | |||||
| assert domain_model.node_execution_id == db_model.node_execution_id | |||||
| assert domain_model.node_id == db_model.node_id | |||||
| assert domain_model.node_type == NodeType(db_model.node_type) | |||||
| assert domain_model.title == db_model.title | |||||
| assert domain_model.inputs == inputs_dict | |||||
| assert domain_model.process_data == process_data_dict | |||||
| assert domain_model.outputs == outputs_dict | |||||
| assert domain_model.status == NodeExecutionStatus(db_model.status) | |||||
| assert domain_model.error == db_model.error | |||||
| assert domain_model.elapsed_time == db_model.elapsed_time | |||||
| assert domain_model.metadata == metadata_dict | |||||
| assert domain_model.created_at == db_model.created_at | |||||
| assert domain_model.finished_at == db_model.finished_at |