Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -1,3 +1,6 @@
+from typing import cast
+
+from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
@@ -12,8 +15,7 @@ from fields.workflow_run_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from models import App
-from models.model import AppMode
+from models import Account, App, AppMode, EndUser
 from services.workflow_run_service import WorkflowRunService
@@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
         run_id = str(run_id)
 
         workflow_run_service = WorkflowRunService()
-        node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
+        user = cast("Account | EndUser", current_user)
+        node_executions = workflow_run_service.get_workflow_run_node_executions(
+            app_model=app_model,
+            run_id=run_id,
+            user=user,
+        )
 
         return {"data": node_executions}
@@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
-from models.account import Account
-from models.model import App, Conversation, EndUser, Message
-from models.workflow import Workflow
+from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
 from services.conversation_service import ConversationService
 from services.errors.message import MessageNotExistsError
@@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
 
         return self._generate(
@@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )
 
         return self._generate(
@@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )
 
         return self._generate(
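Note: the three hunks above replace the old `tenant_id=` argument with the richer context bundle. Condensed into a hypothetical helper mirroring these call sites (names taken from the diff; not itself part of the PR):

```python
from typing import Union

from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models import Account, EndUser, WorkflowNodeExecutionTriggeredFrom


def build_node_execution_repository(
    user: Union[Account, EndUser],
    app_id: str,
    triggered_from: WorkflowNodeExecutionTriggeredFrom,
) -> SQLAlchemyWorkflowNodeExecutionRepository:
    """Hypothetical consolidation of the construction sites in this PR."""
    # expire_on_commit=False lets returned objects outlive the session.
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_factory,
        user=user,  # tenant_id and the creator role are derived from this
        app_id=app_id,
        triggered_from=triggered_from,
    )
```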
@@ -70,7 +70,7 @@ from events.message_event import message_was_created
 from extensions.ext_database import db
 from models import Conversation, EndUser, Message, MessageFile
 from models.account import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.workflow import (
     Workflow,
     WorkflowRunStatus,
@@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline:
         if isinstance(user, EndUser):
             self._user_id = user.id
             user_session_id = user.session_id
-            self._created_by_role = CreatedByRole.END_USER
+            self._created_by_role = CreatorUserRole.END_USER
         elif isinstance(user, Account):
             self._user_id = user.id
             user_session_id = user.id
-            self._created_by_role = CreatedByRole.ACCOUNT
+            self._created_by_role = CreatorUserRole.ACCOUNT
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
@@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline:
                     url=file["remote_url"],
                     belongs_to="assistant",
                     upload_file_id=file["related_id"],
-                    created_by_role=CreatedByRole.ACCOUNT
+                    created_by_role=CreatorUserRole.ACCOUNT
                     if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                    else CreatedByRole.END_USER,
+                    else CreatorUserRole.END_USER,
                     created_by=message.from_account_id or message.from_end_user_id or "",
                 )
                 for file in self._recorded_files
@@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from models import Account
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationNotExistsError
@@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
                 belongs_to="user",
                 url=file.remote_url,
                 upload_file_id=file.related_id,
-                created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
+                created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
                 created_by=account_id or end_user_id or "",
             )
             db.session.add(message_file)
@@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
 from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from extensions.ext_database import db
 from factories import file_factory
-from models import Account, App, EndUser, Workflow
+from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
 
         return self._generate(
@@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
 
         return self._generate(
@@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
             session_factory=session_factory,
-            tenant_id=application_generate_entity.app_config.tenant_id,
+            user=user,
             app_id=application_generate_entity.app_config.app_id,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
         )
 
         return self._generate(
@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
+from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
 from models.workflow import WorkflowNodeExecutionStatus
@@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
         created_at: int
         extras: dict = {}
         parallel_id: Optional[str] = None
@@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
-        process_data: Optional[dict] = None
-        outputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
+        process_data: Optional[Mapping[str, Any]] = None
+        outputs: Optional[Mapping[str, Any]] = None
         status: str
         error: Optional[str] = None
         elapsed_time: float
-        execution_metadata: Optional[dict] = None
+        execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
         created_at: int
         finished_at: int
         files: Optional[Sequence[Mapping[str, Any]]] = []
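Note: `dict` → `Mapping` in these response models is a variance/intent change, not a runtime one: Pydantic still accepts plain dicts at validation time, while the annotation tells consumers the field should be treated as read-only. A toy model (not from this PR) showing that behavior:

```python
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel


class NodePayload(BaseModel):
    """Toy stand-in for the stream-response Data models above."""

    # Mapping advertises read-only intent; callers still pass plain dicts.
    inputs: Optional[Mapping[str, Any]] = None


payload = NodePayload(inputs={"query": "hello"})
assert payload.inputs is not None
assert payload.inputs["query"] == "hello"  # reads work; mutation is not part of the contract
```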
@@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
         title: str
         index: int
         predecessor_node_id: Optional[str] = None
-        inputs: Optional[dict] = None
-        process_data: Optional[dict] = None
-        outputs: Optional[dict] = None
+        inputs: Optional[Mapping[str, Any]] = None
+        process_data: Optional[Mapping[str, Any]] = None
+        outputs: Optional[Mapping[str, Any]] = None
         status: str
         error: Optional[str] = None
         elapsed_time: float
-        execution_metadata: Optional[dict] = None
+        execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
         created_at: int
         finished_at: int
         files: Optional[Sequence[Mapping[str, Any]]] = []
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any, Optional, Union
@@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
         description="The status message of the span. Additional field for context of the event. E.g. the error "
         "message of an error event.",
     )
-    input: Optional[Union[str, dict[str, Any], list, None]] = Field(
+    input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
         default=None, description="The input of the span. Can be any JSON object."
     )
-    output: Optional[Union[str, dict[str, Any], list, None]] = Field(
+    output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
         default=None, description="The output of the span. Can be any JSON object."
     )
     version: Optional[str] = Field(
@@ -1,11 +1,10 @@
-import json
 import logging
 import os
 from datetime import datetime, timedelta
 from typing import Optional
 
 from langfuse import Langfuse  # type: ignore
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangfuseConfig
@@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
 )
 from core.ops.utils import filter_none_values
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser
+from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
 
         # Get all executions for this workflow run
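Note: this creator-account lookup is repeated verbatim in the LangSmith, Opik, and Weave tracers below. A shared helper, sketched here as a hypothetical follow-up refactor (built entirely from the diff's own calls), would keep the four copies from drifting:

```python
from sqlalchemy import Engine
from sqlalchemy.orm import Session

from models import Account, App


def find_app_service_account(engine: Engine, app_id: str | None) -> Account:
    """Resolve the Account that created an app, failing fast on bad data.

    Hypothetical consolidation of the lookup repeated across trace integrations.
    """
    if not app_id:
        raise ValueError("No app_id found in trace_info metadata")
    with Session(engine, expire_on_commit=False) as session:
        app = session.query(App).filter(App.id == app_id).first()
        if not app:
            raise ValueError(f"App with id {app_id} not found")
        if not app.created_by:
            raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
        account = session.query(Account).filter(Account.id == app.created_by).first()
        if not account:
            raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
        return account
```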
@@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            metadata = {str(k): v for k, v in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                     "status": status,
                 }
             )
-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
             model_provider = process_data.get("model_provider", None)
             model_name = process_data.get("model_name", None)
             if model_provider is not None and model_name is not None:
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any, Optional, Union
@@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
 class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
     name: Optional[str] = Field(..., description="Name of the run")
-    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
-    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
+    inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
+    outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
     run_type: LangSmithRunType = Field(..., description="Type of the run")
     start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
     end_time: Optional[datetime | str] = Field(None, description="End time of the run")
@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,7 +6,7 @@ from typing import Optional, cast
 from langsmith import Client
 from langsmith.schemas import RunBase
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangSmithConfig
@@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
 
         # Get all executions for this workflow run
@@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            node_total_tokens = execution_metadata.get("total_tokens", 0)
-            metadata = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
+            metadata = {str(key): value for key, value in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                 }
             )
-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
 
             if process_data and process_data.get("model_mode") == "chat":
                 run_type = LangSmithRunType.llm
@@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                         "ls_model_name": process_data.get("model_name", ""),
                     }
                 )
-            elif node_type == "knowledge-retrieval":
+            elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
                 run_type = LangSmithRunType.retriever
             else:
                 run_type = LangSmithRunType.tool
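Note: the literal-to-enum swaps in this hunk are behavior-preserving because `NodeType` is a `StrEnum`. A minimal demonstration with an illustrative two-member enum (the real one lives in core.workflow.nodes.enums):

```python
from enum import StrEnum


class NodeType(StrEnum):
    """Illustrative subset of core.workflow.nodes.enums.NodeType."""

    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"


# StrEnum members compare equal to their string values, so the swapped
# comparisons match exactly the same rows as the old literals did.
assert NodeType.LLM == "llm"
assert "knowledge-retrieval" == NodeType.KNOWLEDGE_RETRIEVAL

# A misspelled member name fails loudly with AttributeError, whereas a
# misspelled string literal would just never match anything.
```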
@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,7 +6,7 @@ from typing import Optional, cast
 from opik import Opik, Trace
 from opik.id_helpers import uuid4_to_uuid7
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import OpikConfig
@@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
         workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
         )
 
         # Get all executions for this workflow run
@@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance):
         for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            metadata = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            metadata = {str(k): v for k, v in execution_metadata.items()}
             metadata.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance):
                 }
             )
-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
 
             provider = None
             model = None
@@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
             parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
 
             if not total_tokens:
-                total_tokens = execution_metadata.get("total_tokens", 0)
+                total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
 
             span_data = {
                 "trace_id": opik_trace_id,
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from typing import Any, Optional, Union
 
 from pydantic import BaseModel, Field, field_validator
@@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
 class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
     id: str = Field(..., description="ID of the trace")
     op: str = Field(..., description="Name of the operation")
-    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
-    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
+    inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
+    outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
     attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
         None, description="Metadata and attributes associated with trace"
     )
@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 import uuid
@@ -7,6 +6,7 @@ from typing import Any, Optional, cast
 import wandb
 import weave
+from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import WeaveConfig
@@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution
+from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
             self.start_call(workflow_run, parent_run_id=trace_info.message_id)
 
-        # through workflow_run_id get all_nodes_execution
-        workflow_nodes_execution_id_records = (
-            db.session.query(WorkflowNodeExecution.id)
-            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
-            .all()
-        )
+        # through workflow_run_id get all_nodes_execution using repository
+        session_factory = sessionmaker(bind=db.engine)
+        # Find the app's creator account
+        with Session(db.engine, expire_on_commit=False) as session:
+            # Get the app to find its creator
+            app_id = trace_info.metadata.get("app_id")
+            if not app_id:
+                raise ValueError("No app_id found in trace_info metadata")
+
+            app = session.query(App).filter(App.id == app_id).first()
+            if not app:
+                raise ValueError(f"App with id {app_id} not found")
+
+            if not app.created_by:
+                raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+            service_account = session.query(Account).filter(Account.id == app.created_by).first()
+            if not service_account:
+                raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            user=service_account,
+            app_id=trace_info.metadata.get("app_id"),
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+        )
 
-        for node_execution_id_record in workflow_nodes_execution_id_records:
-            node_execution = (
-                db.session.query(
-                    WorkflowNodeExecution.id,
-                    WorkflowNodeExecution.tenant_id,
-                    WorkflowNodeExecution.app_id,
-                    WorkflowNodeExecution.title,
-                    WorkflowNodeExecution.node_type,
-                    WorkflowNodeExecution.status,
-                    WorkflowNodeExecution.inputs,
-                    WorkflowNodeExecution.outputs,
-                    WorkflowNodeExecution.created_at,
-                    WorkflowNodeExecution.elapsed_time,
-                    WorkflowNodeExecution.process_data,
-                    WorkflowNodeExecution.execution_metadata,
-                )
-                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
-                .first()
-            )
-
-            if not node_execution:
-                continue
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
+        )
+        for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
-            tenant_id = node_execution.tenant_id
-            app_id = node_execution.app_id
+            tenant_id = trace_info.tenant_id  # Use from trace_info instead
+            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
             node_name = node_execution.title
             node_type = node_execution.node_type
             status = node_execution.status
-            if node_type == "llm":
-                inputs = (
-                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
-                )
+            if node_type == NodeType.LLM:
+                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
             else:
-                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                inputs = node_execution.inputs if node_execution.inputs else {}
+            outputs = node_execution.outputs if node_execution.outputs else {}
             created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
-            execution_metadata = (
-                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
-            )
-            node_total_tokens = execution_metadata.get("total_tokens", 0)
-            attributes = execution_metadata.copy()
+            execution_metadata = node_execution.metadata if node_execution.metadata else {}
+            node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
+            attributes = {str(k): v for k, v in execution_metadata.items()}
             attributes.update(
                 {
                     "workflow_run_id": trace_info.workflow_run_id,
@@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
                 }
             )
-            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            process_data = node_execution.process_data if node_execution.process_data else {}
 
             if process_data and process_data.get("model_mode") == "chat":
                 attributes.update(
                     {
@@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import UploadFile
 
 logger = logging.getLogger(__name__)
@@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor):
                 extension=str(image_ext),
                 mime_type=mime_type or "",
                 created_by=self.user_id,
-                created_by_role=CreatedByRole.ACCOUNT,
+                created_by_role=CreatorUserRole.ACCOUNT,
                 created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 used=True,
                 used_by=self.user_id,
@@ -2,16 +2,29 @@
 SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
 """
 
+import json
 import logging
 from collections.abc import Sequence
-from typing import Optional
+from typing import Optional, Union
 
 from sqlalchemy import UnaryExpression, asc, delete, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
+from core.workflow.entities.node_execution_entities import (
+    NodeExecution,
+    NodeExecutionStatus,
+)
+from core.workflow.nodes.enums import NodeType
 from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
-from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
+from models import (
+    Account,
+    CreatorUserRole,
+    EndUser,
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+    WorkflowNodeExecutionTriggeredFrom,
+)
 
 logger = logging.getLogger(__name__)
@@ -23,16 +36,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
     This implementation supports multi-tenancy by filtering operations based on tenant_id.
     Each method creates its own session, handles the transaction, and commits changes
     to the database. This prevents long-running connections in the workflow core.
+
+    This implementation also includes an in-memory cache for node executions to improve
+    performance by reducing database queries.
     """
 
-    def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
+    def __init__(
+        self,
+        session_factory: sessionmaker | Engine,
+        user: Union[Account, EndUser],
+        app_id: Optional[str],
+        triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
+    ):
         """
-        Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
+        Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
 
         Args:
             session_factory: SQLAlchemy sessionmaker or engine for creating sessions
-            tenant_id: Tenant ID for multi-tenancy
-            app_id: Optional app ID for filtering by application
+            user: Account or EndUser object containing tenant_id, user ID, and role information
+            app_id: App ID for filtering by application (can be None)
+            triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
         """
         # If an engine is provided, create a sessionmaker from it
         if isinstance(session_factory, Engine):
@@ -44,38 +67,155 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
                 f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
             )
 
+        # Extract tenant_id from user
+        tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
+        if not tenant_id:
+            raise ValueError("User must have a tenant_id or current_tenant_id")
         self._tenant_id = tenant_id
+
+        # Store app context
         self._app_id = app_id
 
+        # Extract user context
+        self._triggered_from = triggered_from
+        self._creator_user_id = user.id
+
+        # Determine user role based on user type
+        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
+
+        # Initialize in-memory cache for node executions
+        # Key: node_execution_id, Value: NodeExecution
+        self._node_execution_cache: dict[str, NodeExecution] = {}
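Note: tenant derivation differs by principal type: an `EndUser` carries `tenant_id` directly, while an `Account` exposes its active workspace via `current_tenant_id`. The constructor logic above, distilled into a standalone function for reference:

```python
from typing import Union

from models import Account, EndUser


def resolve_tenant_id(user: Union[Account, EndUser]) -> str:
    """The repository constructor's tenant derivation, shown in isolation."""
    tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
    if not tenant_id:
        raise ValueError("User must have a tenant_id or current_tenant_id")
    return tenant_id
```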
-    def save(self, execution: WorkflowNodeExecution) -> None:
-        """
-        Save a WorkflowNodeExecution instance and commit changes to the database.
-
-        Args:
-            execution: The WorkflowNodeExecution instance to save
-        """
-        with self._session_factory() as session:
-            # Ensure tenant_id is set
-            if not execution.tenant_id:
-                execution.tenant_id = self._tenant_id
-
-            # Set app_id if provided and not already set
-            if self._app_id and not execution.app_id:
-                execution.app_id = self._app_id
-
-            session.add(execution)
-            session.commit()
+    def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
+        """
+        Convert a database model to a domain model.
+
+        Args:
+            db_model: The database model to convert
+
+        Returns:
+            The domain model
+        """
+        # Parse JSON fields
+        inputs = db_model.inputs_dict
+        process_data = db_model.process_data_dict
+        outputs = db_model.outputs_dict
+        metadata = db_model.execution_metadata_dict
+
+        # Convert status to domain enum
+        status = NodeExecutionStatus(db_model.status)
+
+        return NodeExecution(
+            id=db_model.id,
+            node_execution_id=db_model.node_execution_id,
+            workflow_id=db_model.workflow_id,
+            workflow_run_id=db_model.workflow_run_id,
+            index=db_model.index,
+            predecessor_node_id=db_model.predecessor_node_id,
+            node_id=db_model.node_id,
+            node_type=NodeType(db_model.node_type),
+            title=db_model.title,
+            inputs=inputs,
+            process_data=process_data,
+            outputs=outputs,
+            status=status,
+            error=db_model.error,
+            elapsed_time=db_model.elapsed_time,
+            metadata=metadata,
+            created_at=db_model.created_at,
+            finished_at=db_model.finished_at,
+        )
+
+    def _to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
+        """
+        Convert a domain model to a database model.
+
+        Args:
+            domain_model: The domain model to convert
+
+        Returns:
+            The database model
+        """
+        # Use values from constructor if provided
+        if not self._triggered_from:
+            raise ValueError("triggered_from is required in repository constructor")
+        if not self._creator_user_id:
+            raise ValueError("created_by is required in repository constructor")
+        if not self._creator_user_role:
+            raise ValueError("created_by_role is required in repository constructor")
+
+        db_model = WorkflowNodeExecution()
+        db_model.id = domain_model.id
+        db_model.tenant_id = self._tenant_id
+        if self._app_id is not None:
+            db_model.app_id = self._app_id
+        db_model.workflow_id = domain_model.workflow_id
+        db_model.triggered_from = self._triggered_from
+        db_model.workflow_run_id = domain_model.workflow_run_id
+        db_model.index = domain_model.index
+        db_model.predecessor_node_id = domain_model.predecessor_node_id
+        db_model.node_execution_id = domain_model.node_execution_id
+        db_model.node_id = domain_model.node_id
+        db_model.node_type = domain_model.node_type
+        db_model.title = domain_model.title
+        db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
+        db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
+        db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
+        db_model.status = domain_model.status
+        db_model.error = domain_model.error
+        db_model.elapsed_time = domain_model.elapsed_time
+        db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None
+        db_model.created_at = domain_model.created_at
+        db_model.created_by_role = self._creator_user_role
+        db_model.created_by = self._creator_user_id
+        db_model.finished_at = domain_model.finished_at
+        return db_model
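Note: `_to_db_model` and `_to_domain_model` are intended to be inverses for the domain-owned fields (the tenant, app, and creator columns are supplied by the repository). A hedged pytest-style sketch of that property, assuming a repository instance wired as in this PR and DB-model `*_dict` properties that parse the JSON columns:

```python
import uuid
from datetime import datetime

from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.nodes.enums import NodeType


def test_db_round_trip_preserves_domain_fields(repo) -> None:
    """`repo` is a SQLAlchemyWorkflowNodeExecutionRepository built as in this PR."""
    execution = NodeExecution(
        id=str(uuid.uuid4()),
        workflow_id=str(uuid.uuid4()),
        index=1,
        node_id="start",
        node_type=NodeType.START,
        title="Start",
        inputs={"query": "hi"},
        created_at=datetime.now(),
    )
    # Serialize to JSON columns and parse straight back out again.
    restored = repo._to_domain_model(repo._to_db_model(execution))
    assert restored.inputs == execution.inputs
    assert restored.status == execution.status
```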
+    def save(self, execution: NodeExecution) -> None:
+        """
+        Save or update a NodeExecution instance and commit changes to the database.
+
+        This method handles both creating new records and updating existing ones.
+        It determines whether to create or update based on whether the record
+        already exists in the database. It also updates the in-memory cache.
+
+        Args:
+            execution: The NodeExecution instance to save or update
+        """
+        with self._session_factory() as session:
+            # Convert domain model to database model using instance attributes
+            db_model = self._to_db_model(execution)
+
+            # Use merge which will handle both insert and update
+            session.merge(db_model)
+            session.commit()
+
+        # Update the cache if node_execution_id is present
+        if execution.node_execution_id:
+            logger.debug(f"Updating cache for node_execution_id: {execution.node_execution_id}")
+            self._node_execution_cache[execution.node_execution_id] = execution
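Note: with the cache write at the end of `save()`, a read-after-write of the same `node_execution_id` never touches the database. A usage sketch (assuming `repo` is a repository built as above and `execution.node_execution_id` is set):

```python
repo.save(execution)  # merge() INSERTs or UPDATEs, then the domain object is cached

# Served from the in-memory dict; no SQL statement is emitted for this lookup.
assert repo.get_by_node_execution_id(execution.node_execution_id) is execution
```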
-    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
+    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
         """
-        Retrieve a WorkflowNodeExecution by its node_execution_id.
+        Retrieve a NodeExecution by its node_execution_id.
+
+        First checks the in-memory cache, and if not found, queries the database.
+        If found in the database, adds it to the cache for future lookups.
 
         Args:
             node_execution_id: The node execution ID
 
         Returns:
-            The WorkflowNodeExecution instance if found, None otherwise
+            The NodeExecution instance if found, None otherwise
         """
+        # First check the cache
+        if node_execution_id in self._node_execution_cache:
+            logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
+            return self._node_execution_cache[node_execution_id]
+
+        # If not in cache, query the database
+        logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
         with self._session_factory() as session:
             stmt = select(WorkflowNodeExecution).where(
                 WorkflowNodeExecution.node_execution_id == node_execution_id,
@@ -85,15 +225,63 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             if self._app_id:
                 stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
 
-            return session.scalar(stmt)
+            db_model = session.scalar(stmt)
+            if db_model:
+                # Convert to domain model
+                domain_model = self._to_domain_model(db_model)
+                # Add to cache
+                self._node_execution_cache[node_execution_id] = domain_model
+                return domain_model
+            return None
 
     def get_by_workflow_run(
         self,
         workflow_run_id: str,
         order_config: Optional[OrderConfig] = None,
-    ) -> Sequence[WorkflowNodeExecution]:
+    ) -> Sequence[NodeExecution]:
+        """
+        Retrieve all NodeExecution instances for a specific workflow run.
+
+        This method always queries the database to ensure complete and ordered results,
+        but updates the cache with any retrieved executions.
+
+        Args:
+            workflow_run_id: The workflow run ID
+            order_config: Optional configuration for ordering results
+                order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
+                order_config.order_direction: Direction to order ("asc" or "desc")
+
+        Returns:
+            A list of NodeExecution instances
+        """
+        # Get the raw database models using the new method
+        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
+
+        # Convert database models to domain models and update cache
+        domain_models = []
+        for model in db_models:
+            domain_model = self._to_domain_model(model)
+            # Update cache if node_execution_id is present
+            if domain_model.node_execution_id:
+                self._node_execution_cache[domain_model.node_execution_id] = domain_model
+            domain_models.append(domain_model)
+
+        return domain_models
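Note: ordering is expressed through the `OrderConfig` dataclass from the repository protocol module; the field names must be column attributes of `WorkflowNodeExecution`. Example call (assuming `repo` and `run_id` from the surrounding context):

```python
from core.workflow.repository.workflow_node_execution_repository import OrderConfig

# Trace order: ascending index, as used by the workflow run views.
executions = repo.get_by_workflow_run(
    workflow_run_id=run_id,
    order_config=OrderConfig(order_by=["index"], order_direction="asc"),
)
```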
+    def get_db_models_by_workflow_run(
+        self,
+        workflow_run_id: str,
+        order_config: Optional[OrderConfig] = None,
+    ) -> Sequence[WorkflowNodeExecution]:
         """
-        Retrieve all WorkflowNodeExecution instances for a specific workflow run.
+        Retrieve all WorkflowNodeExecution database models for a specific workflow run.
+
+        This method is similar to get_by_workflow_run but returns the raw database models
+        instead of converting them to domain models. This can be useful when direct access
+        to database model properties is needed.
 
         Args:
             workflow_run_id: The workflow run ID
@@ -102,7 +290,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             order_config.order_direction: Direction to order ("asc" or "desc")
 
         Returns:
-            A list of WorkflowNodeExecution instances
+            A list of WorkflowNodeExecution database models
         """
         with self._session_factory() as session:
             stmt = select(WorkflowNodeExecution).where(
@@ -129,17 +317,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             if order_columns:
                 stmt = stmt.order_by(*order_columns)
 
-            return session.scalars(stmt).all()
+            db_models = session.scalars(stmt).all()
+
+            # Note: We don't update the cache here since we're returning raw DB models
+            # and not converting to domain models
+
+            return db_models
-    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
+    def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
         """
-        Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
+        Retrieve all running NodeExecution instances for a specific workflow run.
+
+        This method queries the database directly and updates the cache with any
+        retrieved executions that have a node_execution_id.
 
         Args:
             workflow_run_id: The workflow run ID
 
         Returns:
-            A list of running WorkflowNodeExecution instances
+            A list of running NodeExecution instances
         """
         with self._session_factory() as session:
             stmt = select(WorkflowNodeExecution).where(
@@ -152,26 +348,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             if self._app_id:
                 stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
 
-            return session.scalars(stmt).all()
-
-    def update(self, execution: WorkflowNodeExecution) -> None:
-        """
-        Update an existing WorkflowNodeExecution instance and commit changes to the database.
-
-        Args:
-            execution: The WorkflowNodeExecution instance to update
-        """
-        with self._session_factory() as session:
-            # Ensure tenant_id is set
-            if not execution.tenant_id:
-                execution.tenant_id = self._tenant_id
-
-            # Set app_id if provided and not already set
-            if self._app_id and not execution.app_id:
-                execution.app_id = self._app_id
-
-            session.merge(execution)
-            session.commit()
+            db_models = session.scalars(stmt).all()
+            domain_models = []
+
+            for model in db_models:
+                domain_model = self._to_domain_model(model)
+                # Update cache if node_execution_id is present
+                if domain_model.node_execution_id:
+                    self._node_execution_cache[domain_model.node_execution_id] = domain_model
+                domain_models.append(domain_model)
+
+            return domain_models
     def clear(self) -> None:
         """
@@ -179,6 +366,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         This method deletes all WorkflowNodeExecution records that match the tenant_id
         and app_id (if provided) associated with this repository instance.
+        It also clears the in-memory cache.
         """
         with self._session_factory() as session:
             stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
@@ -194,3 +382,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
                 f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
                 + (f" and app {self._app_id}" if self._app_id else "")
             )
+
+            # Clear the in-memory cache
+            self._node_execution_cache.clear()
+            logger.info("Cleared in-memory node execution cache")
@@ -32,7 +32,7 @@ from core.tools.errors import (
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
-from models.enums import CreatedByRole
+from models.enums import CreatorUserRole
 from models.model import Message, MessageFile
@@ -339,9 +339,9 @@ class ToolEngine:
                 url=message.url,
                 upload_file_id=tool_file_id,
                 created_by_role=(
-                    CreatedByRole.ACCOUNT
+                    CreatorUserRole.ACCOUNT
                     if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                    else CreatedByRole.END_USER
+                    else CreatorUserRole.END_USER
                 ),
                 created_by=user_id,
             )
@@ -0,0 +1,98 @@
+"""
+Domain entities for workflow node execution.
+
+This module contains the domain model for workflow node execution, which is used
+by the core workflow module. These models are independent of the storage mechanism
+and don't contain implementation details like tenant_id, app_id, etc.
+"""
+
+from collections.abc import Mapping
+from datetime import datetime
+from enum import StrEnum
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeRunMetadataKey
+from core.workflow.nodes.enums import NodeType
+
+
+class NodeExecutionStatus(StrEnum):
+    """
+    Node Execution Status Enum.
+    """
+
+    RUNNING = "running"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+    EXCEPTION = "exception"
+    RETRY = "retry"
+
+
+class NodeExecution(BaseModel):
+    """
+    Domain model for workflow node execution.
+
+    This model represents the core business entity of a node execution,
+    without implementation details like tenant_id, app_id, etc.
+
+    Note: User/context-specific fields (triggered_from, created_by, created_by_role)
+    have been moved to the repository implementation to keep the domain model clean.
+    These fields are still accepted in the constructor for backward compatibility,
+    but they are not stored in the model.
+    """
+
+    # Core identification fields
+    id: str  # Unique identifier for this execution record
+    node_execution_id: Optional[str] = None  # Optional secondary ID for cross-referencing
+    workflow_id: str  # ID of the workflow this node belongs to
+    workflow_run_id: Optional[str] = None  # ID of the specific workflow run (null for single-step debugging)
+
+    # Execution positioning and flow
+    index: int  # Sequence number for ordering in trace visualization
+    predecessor_node_id: Optional[str] = None  # ID of the node that executed before this one
+    node_id: str  # ID of the node being executed
+    node_type: NodeType  # Type of node (e.g., start, llm, knowledge)
+    title: str  # Display title of the node
+
+    # Execution data
+    inputs: Optional[Mapping[str, Any]] = None  # Input variables used by this node
+    process_data: Optional[Mapping[str, Any]] = None  # Intermediate processing data
+    outputs: Optional[Mapping[str, Any]] = None  # Output variables produced by this node
+
+    # Execution state
+    status: NodeExecutionStatus = NodeExecutionStatus.RUNNING  # Current execution status
+    error: Optional[str] = None  # Error message if execution failed
+    elapsed_time: float = Field(default=0.0)  # Time taken for execution in seconds
+
+    # Additional metadata
+    metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None  # Execution metadata (tokens, cost, etc.)
+
+    # Timing information
+    created_at: datetime  # When execution started
+    finished_at: Optional[datetime] = None  # When execution completed
+
+    def update_from_mapping(
+        self,
+        inputs: Optional[Mapping[str, Any]] = None,
+        process_data: Optional[Mapping[str, Any]] = None,
+        outputs: Optional[Mapping[str, Any]] = None,
+        metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
+    ) -> None:
+        """
+        Update the model from mappings.
+
+        Args:
+            inputs: The inputs to update
+            process_data: The process data to update
+            outputs: The outputs to update
+            metadata: The metadata to update
+        """
+        if inputs is not None:
+            self.inputs = dict(inputs)
+        if process_data is not None:
+            self.process_data = dict(process_data)
+        if outputs is not None:
+            self.outputs = dict(outputs)
+        if metadata is not None:
+            self.metadata = dict(metadata)
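Note: the new domain entity is a plain Pydantic model, so building and updating one needs no database at all. A self-contained example of the intended lifecycle, using only names introduced by this PR:

```python
import uuid
from datetime import datetime

from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
    NodeExecution,
    NodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType

execution = NodeExecution(
    id=str(uuid.uuid4()),
    workflow_id=str(uuid.uuid4()),
    index=1,
    node_id="llm-1",
    node_type=NodeType.LLM,
    title="LLM",
    created_at=datetime.now(),  # status defaults to RUNNING
)

# The node finished: record outputs and usage, then flip the status.
execution.update_from_mapping(
    outputs={"text": "hello"},
    metadata={NodeRunMetadataKey.TOTAL_TOKENS: 42},
)
execution.status = NodeExecutionStatus.SUCCEEDED
execution.finished_at = datetime.now()
```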
| @@ -2,12 +2,12 @@ from collections.abc import Sequence | |||
| from dataclasses import dataclass | |||
| from typing import Literal, Optional, Protocol | |||
| from models.workflow import WorkflowNodeExecution | |||
| from core.workflow.entities.node_execution_entities import NodeExecution | |||
| @dataclass | |||
| class OrderConfig: | |||
| """Configuration for ordering WorkflowNodeExecution instances.""" | |||
| """Configuration for ordering NodeExecution instances.""" | |||
| order_by: list[str] | |||
| order_direction: Optional[Literal["asc", "desc"]] = None | |||
| @@ -15,10 +15,10 @@ class OrderConfig: | |||
| class WorkflowNodeExecutionRepository(Protocol): | |||
| """ | |||
| Repository interface for WorkflowNodeExecution. | |||
| Repository interface for NodeExecution. | |||
| This interface defines the contract for accessing and manipulating | |||
| WorkflowNodeExecution data, regardless of the underlying storage mechanism. | |||
| NodeExecution data, regardless of the underlying storage mechanism. | |||
| Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), | |||
| and trigger sources (triggered_from) should be handled at the implementation level, not in | |||
| @@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| application domains or deployment scenarios. | |||
| """ | |||
| def save(self, execution: WorkflowNodeExecution) -> None: | |||
| def save(self, execution: NodeExecution) -> None: | |||
| """ | |||
| Save a WorkflowNodeExecution instance. | |||
| Save or update a NodeExecution instance. | |||
| This method handles both creating new records and updating existing ones. | |||
| The implementation should determine whether to create or update based on | |||
| the execution's ID or other identifying fields. | |||
| Args: | |||
| execution: The WorkflowNodeExecution instance to save | |||
| execution: The NodeExecution instance to save or update | |||
| """ | |||
| ... | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: | |||
| def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: | |||
| """ | |||
| Retrieve a WorkflowNodeExecution by its node_execution_id. | |||
| Retrieve a NodeExecution by its node_execution_id. | |||
| Args: | |||
| node_execution_id: The node execution ID | |||
| Returns: | |||
| The WorkflowNodeExecution instance if found, None otherwise | |||
| The NodeExecution instance if found, None otherwise | |||
| """ | |||
| ... | |||
| @@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| self, | |||
| workflow_run_id: str, | |||
| order_config: Optional[OrderConfig] = None, | |||
| ) -> Sequence[WorkflowNodeExecution]: | |||
| ) -> Sequence[NodeExecution]: | |||
| """ | |||
| Retrieve all WorkflowNodeExecution instances for a specific workflow run. | |||
| Retrieve all NodeExecution instances for a specific workflow run. | |||
| Args: | |||
| workflow_run_id: The workflow run ID | |||
| @@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol): | |||
| order_config.order_direction: Direction to order ("asc" or "desc") | |||
| Returns: | |||
| A list of WorkflowNodeExecution instances | |||
| A list of NodeExecution instances | |||
| """ | |||
| ... | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: | |||
| def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: | |||
| """ | |||
| Retrieve all running WorkflowNodeExecution instances for a specific workflow run. | |||
| Retrieve all running NodeExecution instances for a specific workflow run. | |||
| Args: | |||
| workflow_run_id: The workflow run ID | |||
| Returns: | |||
| A list of running WorkflowNodeExecution instances | |||
| """ | |||
| ... | |||
| def update(self, execution: WorkflowNodeExecution) -> None: | |||
| """ | |||
| Update an existing WorkflowNodeExecution instance. | |||
| Args: | |||
| execution: The WorkflowNodeExecution instance to update | |||
| A list of running NodeExecution instances | |||
| """ | |||
| ... | |||
| def clear(self) -> None: | |||
| """ | |||
| Clear all WorkflowNodeExecution records based on implementation-specific criteria. | |||
| Clear all NodeExecution records based on implementation-specific criteria. | |||
| This method is intended to be used for bulk deletion operations, such as removing | |||
| all records associated with a specific app_id and tenant_id in multi-tenant implementations. | |||
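To make the contract concrete, here is a minimal in-memory implementation that satisfies the protocol. This is an illustrative sketch only, not the SQLAlchemy implementation this change ships:

from collections.abc import Sequence
from typing import Optional

from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.repository.workflow_node_execution_repository import OrderConfig


class InMemoryNodeExecutionRepository:
    """Toy repository keyed by the execution's primary id."""

    def __init__(self) -> None:
        self._store: dict[str, NodeExecution] = {}

    def save(self, execution: NodeExecution) -> None:
        # A second save with the same id overwrites, giving upsert semantics.
        self._store[execution.id] = execution

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
        return next(
            (e for e in self._store.values() if e.node_execution_id == node_execution_id),
            None,
        )

    def get_by_workflow_run(
        self, workflow_run_id: str, order_config: Optional[OrderConfig] = None
    ) -> Sequence[NodeExecution]:
        executions = [e for e in self._store.values() if e.workflow_run_id == workflow_run_id]
        if order_config and order_config.order_by == ["index"]:
            executions.sort(key=lambda e: e.index, reverse=order_config.order_direction == "desc")
        return executions

    def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
        return [
            e for e in self.get_by_workflow_run(workflow_run_id)
            if e.status == NodeExecutionStatus.RUNNING
        ]

    def clear(self) -> None:
        self._store.clear()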
| @@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.enums import CreatedByRole | |||
| from models.enums import CreatorUserRole | |||
| from models.model import EndUser | |||
| from models.workflow import ( | |||
| Workflow, | |||
| @@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if isinstance(user, EndUser): | |||
| self._user_id = user.id | |||
| user_session_id = user.session_id | |||
| self._created_by_role = CreatedByRole.END_USER | |||
| self._created_by_role = CreatorUserRole.END_USER | |||
| elif isinstance(user, Account): | |||
| self._user_id = user.id | |||
| user_session_id = user.id | |||
| self._created_by_role = CreatedByRole.ACCOUNT | |||
| self._created_by_role = CreatorUserRole.ACCOUNT | |||
| else: | |||
| raise ValueError(f"Invalid user type: {type(user)}") | |||
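This user-type branching recurs in several pipelines touched by this change; a small helper of roughly this shape (hypothetical, not part of the diff) would capture the pattern:

from models import Account, CreatorUserRole, EndUser


def resolve_creator_role(user: Account | EndUser) -> CreatorUserRole:
    # Mirrors the isinstance checks above.
    if isinstance(user, EndUser):
        return CreatorUserRole.END_USER
    if isinstance(user, Account):
        return CreatorUserRole.ACCOUNT
    raise ValueError(f"Invalid user type: {type(user)}")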
| @@ -46,26 +46,28 @@ from core.app.entities.task_entities import ( | |||
| ) | |||
| from core.app.task_pipeline.exc import WorkflowRunNotFoundError | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import ( | |||
| NodeExecution, | |||
| NodeExecutionStatus, | |||
| ) | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from models.account import Account | |||
| from models.enums import CreatedByRole, WorkflowRunTriggeredFrom | |||
| from models.model import EndUser | |||
| from models.workflow import ( | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| EndUser, | |||
| Workflow, | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| WorkflowRunTriggeredFrom, | |||
| ) | |||
| @@ -78,7 +80,6 @@ class WorkflowCycleManager: | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| ) -> None: | |||
| self._workflow_run: WorkflowRun | None = None | |||
| self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_system_variables = workflow_system_variables | |||
| self._workflow_node_execution_repository = workflow_node_execution_repository | |||
| @@ -89,7 +90,7 @@ class WorkflowCycleManager: | |||
| session: Session, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| created_by_role: CreatedByRole, | |||
| created_by_role: CreatorUserRole, | |||
| ) -> WorkflowRun: | |||
| workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | |||
| workflow = session.scalar(workflow_stmt) | |||
| @@ -258,21 +259,22 @@ class WorkflowCycleManager: | |||
| workflow_run.exceptions_count = exceptions_count | |||
| # Use the instance repository to find running executions for a workflow run | |||
| running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( | |||
| running_domain_executions = self._workflow_node_execution_repository.get_running_executions( | |||
| workflow_run_id=workflow_run.id | |||
| ) | |||
| # Update the cache with the retrieved executions | |||
| for execution in running_workflow_node_executions: | |||
| if execution.node_execution_id: | |||
| self._workflow_node_executions[execution.node_execution_id] = execution | |||
| # Update the domain models | |||
| now = datetime.now(UTC).replace(tzinfo=None) | |||
| for domain_execution in running_domain_executions: | |||
| if domain_execution.node_execution_id: | |||
| # Update the domain model | |||
| domain_execution.status = NodeExecutionStatus.FAILED | |||
| domain_execution.error = error | |||
| domain_execution.finished_at = now | |||
| domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() | |||
| for workflow_node_execution in running_workflow_node_executions: | |||
| now = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| workflow_node_execution.finished_at = now | |||
| workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() | |||
| # Update the repository with the domain model | |||
| self._workflow_node_execution_repository.save(domain_execution) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -286,63 +288,67 @@ class WorkflowCycleManager: | |||
| return workflow_run | |||
| def _handle_node_execution_start( | |||
| self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent | |||
| ) -> WorkflowNodeExecution: | |||
| workflow_node_execution = WorkflowNodeExecution() | |||
| workflow_node_execution.id = str(uuid4()) | |||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||
| workflow_node_execution.app_id = workflow_run.app_id | |||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||
| workflow_node_execution.predecessor_node_id = event.predecessor_node_id | |||
| workflow_node_execution.index = event.node_run_index | |||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||
| workflow_node_execution.node_id = event.node_id | |||
| workflow_node_execution.node_type = event.node_type.value | |||
| workflow_node_execution.title = event.node_data.title | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value | |||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||
| workflow_node_execution.created_by = workflow_run.created_by | |||
| workflow_node_execution.execution_metadata = json.dumps( | |||
| { | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||
| } | |||
| def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: | |||
| # Create a domain model | |||
| created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| metadata = { | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||
| } | |||
| domain_execution = NodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_run.workflow_id, | |||
| workflow_run_id=workflow_run.id, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| index=event.node_run_index, | |||
| node_execution_id=event.node_execution_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| title=event.node_data.title, | |||
| status=NodeExecutionStatus.RUNNING, | |||
| metadata=metadata, | |||
| created_at=created_at, | |||
| ) | |||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| # Use the instance repository to save the workflow node execution | |||
| self._workflow_node_execution_repository.save(workflow_node_execution) | |||
| # Use the instance repository to save the domain model | |||
| self._workflow_node_execution_repository.save(domain_execution) | |||
| return domain_execution | |||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||
| return workflow_node_execution | |||
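Note that the domain model created above carries no tenant, app, or creator context. Judging from the repository tests later in this diff, which stub a _to_db_model method, the SQLAlchemy repository presumably restores that context when persisting. A rough sketch, with the parameter names assumed rather than quoted from the implementation:

import json

from core.workflow.entities.node_execution_entities import NodeExecution
from models import CreatorUserRole, WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom


def to_db_model(
    execution: NodeExecution,
    *,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
    created_by: str,
    created_by_role: CreatorUserRole,
) -> WorkflowNodeExecution:
    db_model = WorkflowNodeExecution()
    db_model.id = execution.id
    # Context that the domain model deliberately omits is supplied here.
    db_model.tenant_id = tenant_id
    db_model.app_id = app_id
    db_model.triggered_from = triggered_from.value if triggered_from else None
    db_model.created_by = created_by
    db_model.created_by_role = created_by_role.value
    # The remaining columns map one-to-one from the domain model.
    db_model.workflow_id = execution.workflow_id
    db_model.workflow_run_id = execution.workflow_run_id
    db_model.node_execution_id = execution.node_execution_id
    db_model.node_id = execution.node_id
    db_model.node_type = execution.node_type.value
    db_model.title = execution.title
    db_model.index = execution.index
    db_model.status = execution.status.value
    db_model.inputs = json.dumps(dict(execution.inputs)) if execution.inputs else None
    db_model.outputs = json.dumps(dict(execution.outputs)) if execution.outputs else None
    db_model.created_at = execution.created_at
    db_model.finished_at = execution.finished_at
    return db_model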
| def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: | |||
| # Get the domain model from repository | |||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||
| if not domain_execution: | |||
| raise ValueError(f"Domain node execution not found: {event.node_execution_id}") | |||
| def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: | |||
| workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) | |||
| # Process data | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| execution_metadata_dict = dict(event.execution_metadata or {}) | |||
| execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None | |||
| # Copy execution metadata into a mutable dict | |||
| execution_metadata_dict = {} | |||
| if event.execution_metadata: | |||
| for key, value in event.execution_metadata.items(): | |||
| execution_metadata_dict[key] = value | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - event.start_at).total_seconds() | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| # Update domain model | |||
| domain_execution.status = NodeExecutionStatus.SUCCEEDED | |||
| domain_execution.update_from_mapping( | |||
| inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict | |||
| ) | |||
| domain_execution.finished_at = finished_at | |||
| domain_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| # Update the repository with the domain model | |||
| self._workflow_node_execution_repository.save(domain_execution) | |||
| # Use the instance repository to update the workflow node execution | |||
| self._workflow_node_execution_repository.update(workflow_node_execution) | |||
| return workflow_node_execution | |||
| return domain_execution | |||
| def _handle_workflow_node_execution_failed( | |||
| self, | |||
| @@ -351,43 +357,52 @@ class WorkflowCycleManager: | |||
| | QueueNodeInIterationFailedEvent | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| ) -> WorkflowNodeExecution: | |||
| ) -> NodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param event: queue node failed event | |||
| :return: | |||
| """ | |||
| workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) | |||
| # Get the domain model from repository | |||
| domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) | |||
| if not domain_execution: | |||
| raise ValueError(f"Domain node execution not found: {event.node_execution_id}") | |||
| # Process data | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| # Copy execution metadata into a mutable dict | |||
| execution_metadata_dict = {} | |||
| if event.execution_metadata: | |||
| for key, value in event.execution_metadata.items(): | |||
| execution_metadata_dict[key] = value | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - event.start_at).total_seconds() | |||
| execution_metadata = ( | |||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||
| ) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| workflow_node_execution.status = ( | |||
| WorkflowNodeExecutionStatus.FAILED.value | |||
| # Update domain model | |||
| domain_execution.status = ( | |||
| NodeExecutionStatus.FAILED | |||
| if not isinstance(event, QueueNodeExceptionEvent) | |||
| else WorkflowNodeExecutionStatus.EXCEPTION.value | |||
| else NodeExecutionStatus.EXCEPTION | |||
| ) | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| domain_execution.error = event.error | |||
| domain_execution.update_from_mapping( | |||
| inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict | |||
| ) | |||
| domain_execution.finished_at = finished_at | |||
| domain_execution.elapsed_time = elapsed_time | |||
| self._workflow_node_execution_repository.update(workflow_node_execution) | |||
| # Update the repository with the domain model | |||
| self._workflow_node_execution_repository.save(domain_execution) | |||
| return workflow_node_execution | |||
| return domain_execution | |||
| def _handle_workflow_node_execution_retried( | |||
| self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent | |||
| ) -> WorkflowNodeExecution: | |||
| ) -> NodeExecution: | |||
| """ | |||
| Workflow node execution retried | |||
| :param workflow_run: workflow run | |||
| @@ -399,47 +414,47 @@ class WorkflowCycleManager: | |||
| elapsed_time = (finished_at - created_at).total_seconds() | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| # Metadata contributed by the retry event context | |||
| origin_metadata = { | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.LOOP_ID: event.in_loop_id, | |||
| } | |||
| merged_metadata = ( | |||
| {**jsonable_encoder(event.execution_metadata), **origin_metadata} | |||
| if event.execution_metadata is not None | |||
| else origin_metadata | |||
| # Copy execution metadata into a mutable dict | |||
| execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} | |||
| if event.execution_metadata: | |||
| for key, value in event.execution_metadata.items(): | |||
| execution_metadata_dict[key] = value | |||
| merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata | |||
| # Create a domain model | |||
| domain_execution = NodeExecution( | |||
| id=str(uuid4()), | |||
| workflow_id=workflow_run.workflow_id, | |||
| workflow_run_id=workflow_run.id, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| node_execution_id=event.node_execution_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| title=event.node_data.title, | |||
| status=NodeExecutionStatus.RETRY, | |||
| created_at=created_at, | |||
| finished_at=finished_at, | |||
| elapsed_time=elapsed_time, | |||
| error=event.error, | |||
| index=event.node_run_index, | |||
| ) | |||
| execution_metadata = json.dumps(merged_metadata) | |||
| workflow_node_execution = WorkflowNodeExecution() | |||
| workflow_node_execution.id = str(uuid4()) | |||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||
| workflow_node_execution.app_id = workflow_run.app_id | |||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||
| workflow_node_execution.predecessor_node_id = event.predecessor_node_id | |||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||
| workflow_node_execution.node_id = event.node_id | |||
| workflow_node_execution.node_type = event.node_type.value | |||
| workflow_node_execution.title = event.node_data.title | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value | |||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||
| workflow_node_execution.created_by = workflow_run.created_by | |||
| workflow_node_execution.created_at = created_at | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| workflow_node_execution.index = event.node_run_index | |||
| # Use the instance repository to save the workflow node execution | |||
| self._workflow_node_execution_repository.save(workflow_node_execution) | |||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||
| return workflow_node_execution | |||
| # Update with mappings | |||
| domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) | |||
| # Use the instance repository to save the domain model | |||
| self._workflow_node_execution_repository.save(domain_execution) | |||
| return domain_execution | |||
| def _workflow_start_to_stream_response( | |||
| self, | |||
| @@ -469,7 +484,7 @@ class WorkflowCycleManager: | |||
| workflow_run: WorkflowRun, | |||
| ) -> WorkflowFinishStreamResponse: | |||
| created_by = None | |||
| if workflow_run.created_by_role == CreatedByRole.ACCOUNT: | |||
| if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: | |||
| stmt = select(Account).where(Account.id == workflow_run.created_by) | |||
| account = session.scalar(stmt) | |||
| if account: | |||
| @@ -478,7 +493,7 @@ class WorkflowCycleManager: | |||
| "name": account.name, | |||
| "email": account.email, | |||
| } | |||
| elif workflow_run.created_by_role == CreatedByRole.END_USER: | |||
| elif workflow_run.created_by_role == CreatorUserRole.END_USER: | |||
| stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) | |||
| end_user = session.scalar(stmt) | |||
| if end_user: | |||
| @@ -515,9 +530,9 @@ class WorkflowCycleManager: | |||
| *, | |||
| event: QueueNodeStartedEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| workflow_node_execution: NodeExecution, | |||
| ) -> Optional[NodeStartStreamResponse]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| return None | |||
| @@ -532,7 +547,7 @@ class WorkflowCycleManager: | |||
| title=workflow_node_execution.title, | |||
| index=workflow_node_execution.index, | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs_dict, | |||
| inputs=workflow_node_execution.inputs, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| @@ -565,9 +580,9 @@ class WorkflowCycleManager: | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| workflow_node_execution: NodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| return None | |||
| @@ -584,16 +599,16 @@ class WorkflowCycleManager: | |||
| index=workflow_node_execution.index, | |||
| title=workflow_node_execution.title, | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs_dict, | |||
| process_data=workflow_node_execution.process_data_dict, | |||
| outputs=workflow_node_execution.outputs_dict, | |||
| inputs=workflow_node_execution.inputs, | |||
| process_data=workflow_node_execution.process_data, | |||
| outputs=workflow_node_execution.outputs, | |||
| status=workflow_node_execution.status, | |||
| error=workflow_node_execution.error, | |||
| elapsed_time=workflow_node_execution.elapsed_time, | |||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||
| execution_metadata=workflow_node_execution.metadata, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| @@ -608,9 +623,9 @@ class WorkflowCycleManager: | |||
| *, | |||
| event: QueueNodeRetryEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| workflow_node_execution: NodeExecution, | |||
| ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: | |||
| return None | |||
| if not workflow_node_execution.workflow_run_id: | |||
| return None | |||
| @@ -627,16 +642,16 @@ class WorkflowCycleManager: | |||
| index=workflow_node_execution.index, | |||
| title=workflow_node_execution.title, | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs_dict, | |||
| process_data=workflow_node_execution.process_data_dict, | |||
| outputs=workflow_node_execution.outputs_dict, | |||
| inputs=workflow_node_execution.inputs, | |||
| process_data=workflow_node_execution.process_data, | |||
| outputs=workflow_node_execution.outputs, | |||
| status=workflow_node_execution.status, | |||
| error=workflow_node_execution.error, | |||
| elapsed_time=workflow_node_execution.elapsed_time, | |||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||
| execution_metadata=workflow_node_execution.metadata, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| @@ -908,23 +923,6 @@ class WorkflowCycleManager: | |||
| return workflow_run | |||
| def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: | |||
| # First check the cache for performance | |||
| if node_execution_id in self._workflow_node_executions: | |||
| cached_execution = self._workflow_node_executions[node_execution_id] | |||
| # No need to merge with session since expire_on_commit=False | |||
| return cached_execution | |||
| # If not in cache, use the instance repository to get by node_execution_id | |||
| execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) | |||
| if not execution: | |||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | |||
| # Update cache | |||
| self._workflow_node_executions[node_execution_id] = execution | |||
| return execution | |||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||
| """ | |||
| Handle agent log | |||
| @@ -27,7 +27,7 @@ from .dataset import ( | |||
| Whitelist, | |||
| ) | |||
| from .engine import db | |||
| from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom | |||
| from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom | |||
| from .model import ( | |||
| ApiRequest, | |||
| ApiToken, | |||
| @@ -112,7 +112,7 @@ __all__ = [ | |||
| "CeleryTaskSet", | |||
| "Conversation", | |||
| "ConversationVariable", | |||
| "CreatedByRole", | |||
| "CreatorUserRole", | |||
| "DataSourceApiKeyAuthBinding", | |||
| "DataSourceOauthBinding", | |||
| "Dataset", | |||
| @@ -1,7 +1,7 @@ | |||
| from enum import StrEnum | |||
| class CreatedByRole(StrEnum): | |||
| class CreatorUserRole(StrEnum): | |||
| ACCOUNT = "account" | |||
| END_USER = "end_user" | |||
| @@ -29,7 +29,7 @@ from libs.helper import generate_string | |||
| from .account import Account, Tenant | |||
| from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatedByRole | |||
| from .enums import CreatorUserRole | |||
| from .types import StringUUID | |||
| from .workflow import WorkflowRunStatus | |||
| @@ -1270,7 +1270,7 @@ class MessageFile(Base): | |||
| url: str | None = None, | |||
| belongs_to: Literal["user", "assistant"] | None = None, | |||
| upload_file_id: str | None = None, | |||
| created_by_role: CreatedByRole, | |||
| created_by_role: CreatorUserRole, | |||
| created_by: str, | |||
| ): | |||
| self.message_id = message_id | |||
| @@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin): | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=True) | |||
| type = db.Column(db.String(255), nullable=False) | |||
| external_user_id = db.Column(db.String(255), nullable=True) | |||
| @@ -1547,7 +1547,7 @@ class UploadFile(Base): | |||
| size: int, | |||
| extension: str, | |||
| mime_type: str, | |||
| created_by_role: CreatedByRole, | |||
| created_by_role: CreatorUserRole, | |||
| created_by: str, | |||
| created_at: datetime, | |||
| used: bool, | |||
| @@ -22,7 +22,7 @@ from libs import helper | |||
| from .account import Account | |||
| from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatedByRole | |||
| from .enums import CreatorUserRole | |||
| from .types import StringUUID | |||
| if TYPE_CHECKING: | |||
| @@ -429,15 +429,15 @@ class WorkflowRun(Base): | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||
| @property | |||
| def graph_dict(self): | |||
| @@ -634,17 +634,17 @@ class WorkflowNodeExecution(Base): | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| # TODO(-LAN-): Avoid using db.session.get() here. | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| # TODO(-LAN-): Avoid using db.session.get() here. | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||
| @property | |||
| def inputs_dict(self): | |||
| @@ -755,15 +755,15 @@ class WorkflowAppLog(Base): | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| created_by_role = CreatorUserRole(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None | |||
| class ConversationVariable(Base): | |||
| @@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.account import Account | |||
| from models.enums import CreatedByRole | |||
| from models.enums import CreatorUserRole | |||
| from models.model import EndUser, UploadFile | |||
| from .errors.file import FileTooLargeError, UnsupportedFileTypeError | |||
| @@ -81,7 +81,7 @@ class FileService: | |||
| size=file_size, | |||
| extension=extension, | |||
| mime_type=mimetype, | |||
| created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), | |||
| created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), | |||
| created_by=user.id, | |||
| created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||
| used=False, | |||
| @@ -133,7 +133,7 @@ class FileService: | |||
| extension="txt", | |||
| mime_type="text/plain", | |||
| created_by=current_user.id, | |||
| created_by_role=CreatedByRole.ACCOUNT, | |||
| created_by_role=CreatorUserRole.ACCOUNT, | |||
| created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||
| used=True, | |||
| used_by=current_user.id, | |||
| @@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select | |||
| from sqlalchemy.orm import Session | |||
| from models import App, EndUser, WorkflowAppLog, WorkflowRun | |||
| from models.enums import CreatedByRole | |||
| from models.enums import CreatorUserRole | |||
| from models.workflow import WorkflowRunStatus | |||
| @@ -58,7 +58,7 @@ class WorkflowAppService: | |||
| stmt = stmt.outerjoin( | |||
| EndUser, | |||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), | |||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), | |||
| ).where(or_(*keyword_conditions)) | |||
| if status: | |||
| @@ -1,4 +1,5 @@ | |||
| import threading | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| import contexts | |||
| @@ -6,11 +7,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | |||
| from extensions.ext_database import db | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from models.model import App | |||
| from models.workflow import ( | |||
| from models import ( | |||
| Account, | |||
| App, | |||
| EndUser, | |||
| WorkflowNodeExecution, | |||
| WorkflowRun, | |||
| WorkflowRunTriggeredFrom, | |||
| ) | |||
| @@ -116,7 +119,12 @@ class WorkflowRunService: | |||
| return workflow_run | |||
| def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: | |||
| def get_workflow_run_node_executions( | |||
| self, | |||
| app_model: App, | |||
| run_id: str, | |||
| user: Account | EndUser, | |||
| ) -> Sequence[WorkflowNodeExecution]: | |||
| """ | |||
| Get workflow run node execution list | |||
| """ | |||
| @@ -128,13 +136,15 @@ class WorkflowRunService: | |||
| if not workflow_run: | |||
| return [] | |||
| # Use the repository to get the node executions | |||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||
| session_factory=db.engine, | |||
| user=user, | |||
| app_id=app_model.id, | |||
| triggered_from=None, | |||
| ) | |||
| # Use the repository to get the node executions with ordering | |||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | |||
| node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) | |||
| node_executions = repository.get_db_models_by_workflow_run(workflow_run_id=run_id, order_config=order_config) | |||
| return list(node_executions) | |||
| return node_executions | |||
| @@ -26,7 +26,7 @@ from core.workflow.workflow_entry import WorkflowEntry | |||
| from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.enums import CreatedByRole | |||
| from models.enums import CreatorUserRole | |||
| from models.model import App, AppMode | |||
| from models.tools import WorkflowToolProvider | |||
| from models.workflow import ( | |||
| @@ -284,9 +284,11 @@ class WorkflowService: | |||
| workflow_node_execution.created_by = account.id | |||
| workflow_node_execution.workflow_id = draft_workflow.id | |||
| # Use the repository to save the workflow node execution | |||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||
| session_factory=db.engine, | |||
| user=account, | |||
| app_id=app_model.id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| repository.save(workflow_node_execution) | |||
| @@ -390,7 +392,7 @@ class WorkflowService: | |||
| workflow_node_execution.node_type = node_instance.node_type | |||
| workflow_node_execution.title = node_instance.node_data.title | |||
| workflow_node_execution.elapsed_time = time.perf_counter() - start_at | |||
| workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value | |||
| workflow_node_execution.created_by_role = CreatorUserRole.ACCOUNT.value | |||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| if run_succeeded and node_run_result: | |||
| @@ -4,16 +4,19 @@ from collections.abc import Callable | |||
| import click | |||
| from celery import shared_task # type: ignore | |||
| from sqlalchemy import delete | |||
| from sqlalchemy import delete, select | |||
| from sqlalchemy.exc import SQLAlchemyError | |||
| from sqlalchemy.orm import Session | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from extensions.ext_database import db | |||
| from models.dataset import AppDatasetJoin | |||
| from models.model import ( | |||
| from models import ( | |||
| Account, | |||
| ApiToken, | |||
| App, | |||
| AppAnnotationHitHistory, | |||
| AppAnnotationSetting, | |||
| AppDatasetJoin, | |||
| AppModelConfig, | |||
| Conversation, | |||
| EndUser, | |||
| @@ -188,9 +191,24 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): | |||
| def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | |||
| # Get app's owner | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id) | |||
| user = session.scalar(stmt) | |||
| if user is None: | |||
| errmsg = ( | |||
| f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found" | |||
| ) | |||
| logging.error(errmsg) | |||
| raise ValueError(errmsg) | |||
| # Create a repository instance for WorkflowNodeExecution | |||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=db.engine, tenant_id=tenant_id, app_id=app_id | |||
| session_factory=db.engine, | |||
| user=user, | |||
| app_id=app_id, | |||
| triggered_from=None, | |||
| ) | |||
| # Use the clear method to delete all records for this tenant_id and app_id | |||
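The clear() call here only makes sense because the repository now knows its tenant and app; under that assumption, the bulk delete can be a single statement, for example:

from sqlalchemy import delete
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models import WorkflowNodeExecution


def clear_node_executions(tenant_id: str, app_id: str) -> None:
    # Illustrative sketch of a tenant/app-scoped bulk delete.
    with Session(db.engine) as session:
        session.execute(
            delete(WorkflowNodeExecution).where(
                WorkflowNodeExecution.tenant_id == tenant_id,
                WorkflowNodeExecution.app_id == app_id,
            )
        )
        session.commit()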
| @@ -16,10 +16,9 @@ from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||
| from models.enums import CreatedByRole | |||
| from models.enums import CreatorUserRole | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| @@ -94,7 +93,7 @@ def mock_workflow_run(): | |||
| workflow_run.app_id = "test-app-id" | |||
| workflow_run.workflow_id = "test-workflow-id" | |||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||
| workflow_run.created_by_role = CreatedByRole.ACCOUNT | |||
| workflow_run.created_by_role = CreatorUserRole.ACCOUNT | |||
| workflow_run.created_by = "test-user-id" | |||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.inputs_dict = {"query": "test query"} | |||
| @@ -107,7 +106,6 @@ def test_init( | |||
| ): | |||
| """Test initialization of WorkflowCycleManager""" | |||
| assert workflow_cycle_manager._workflow_run is None | |||
| assert workflow_cycle_manager._workflow_node_executions == {} | |||
| assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity | |||
| assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables | |||
| assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository | |||
| @@ -123,7 +121,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo | |||
| session=mock_session, | |||
| workflow_id="test-workflow-id", | |||
| user_id="test-user-id", | |||
| created_by_role=CreatedByRole.ACCOUNT, | |||
| created_by_role=CreatorUserRole.ACCOUNT, | |||
| ) | |||
| # Verify the result | |||
| @@ -132,7 +130,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo | |||
| assert workflow_run.workflow_id == mock_workflow.id | |||
| assert workflow_run.sequence_number == 6 # max_sequence + 1 | |||
| assert workflow_run.status == WorkflowRunStatus.RUNNING | |||
| assert workflow_run.created_by_role == CreatedByRole.ACCOUNT | |||
| assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT | |||
| assert workflow_run.created_by == "test-user-id" | |||
| # Verify session.add was called | |||
| @@ -215,24 +213,23 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): | |||
| ) | |||
| # Verify the result | |||
| assert result.tenant_id == mock_workflow_run.tenant_id | |||
| assert result.app_id == mock_workflow_run.app_id | |||
| # NodeExecution doesn't have tenant_id attribute, it's handled at repository level | |||
| # assert result.tenant_id == mock_workflow_run.tenant_id | |||
| # assert result.app_id == mock_workflow_run.app_id | |||
| assert result.workflow_id == mock_workflow_run.workflow_id | |||
| assert result.workflow_run_id == mock_workflow_run.id | |||
| assert result.node_execution_id == event.node_execution_id | |||
| assert result.node_id == event.node_id | |||
| assert result.node_type == event.node_type.value | |||
| assert result.node_type == event.node_type | |||
| assert result.title == event.node_data.title | |||
| assert result.status == WorkflowNodeExecutionStatus.RUNNING.value | |||
| assert result.created_by_role == mock_workflow_run.created_by_role | |||
| assert result.created_by == mock_workflow_run.created_by | |||
| # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level | |||
| # assert result.created_by_role == mock_workflow_run.created_by_role | |||
| # assert result.created_by == mock_workflow_run.created_by | |||
| # Verify save was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) | |||
| # Verify the node execution was added to the cache | |||
| assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result | |||
| def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| """Test _get_workflow_run method""" | |||
| @@ -261,28 +258,24 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): | |||
| event.execution_metadata = {"metadata": "test metadata"} | |||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| # Create a mock workflow node execution | |||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||
| # Create a mock node execution | |||
| node_execution = MagicMock() | |||
| node_execution.node_execution_id = "test-node-execution-id" | |||
| # Mock _get_workflow_node_execution to return the mock node execution | |||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| event=event, | |||
| ) | |||
| # Mock the repository to return the node execution | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| assert result.inputs == json.dumps(event.inputs) | |||
| assert result.process_data == json.dumps(event.process_data) | |||
| assert result.outputs == json.dumps(event.outputs) | |||
| assert result.finished_at is not None | |||
| assert result.elapsed_time is not None | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| event=event, | |||
| ) | |||
| # Verify update was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| # Verify save was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) | |||
| def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): | |||
| @@ -322,27 +315,22 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): | |||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| event.error = "Test error message" | |||
| # Create a mock workflow node execution | |||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||
| # Create a mock node execution | |||
| node_execution = MagicMock() | |||
| node_execution.node_execution_id = "test-node-execution-id" | |||
| # Mock _get_workflow_node_execution to return the mock node execution | |||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| event=event, | |||
| ) | |||
| # Mock the repository to return the node execution | |||
| workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||
| assert result.error == "Test error message" | |||
| assert result.inputs == json.dumps(event.inputs) | |||
| assert result.process_data == json.dumps(event.process_data) | |||
| assert result.outputs == json.dumps(event.outputs) | |||
| assert result.finished_at is not None | |||
| assert result.elapsed_time is not None | |||
| assert result.execution_metadata == json.dumps(event.execution_metadata) | |||
| # Call the method | |||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| event=event, | |||
| ) | |||
| # Verify update was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) | |||
| # Verify the result | |||
| assert result == node_execution | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||
| assert result.error == "Test error message" | |||
| # Verify save was called | |||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) | |||
| @@ -2,15 +2,36 @@ | |||
| Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. | |||
| """ | |||
| from unittest.mock import MagicMock | |||
| import json | |||
| from datetime import datetime | |||
| from unittest.mock import MagicMock, PropertyMock | |||
| import pytest | |||
| from pytest_mock import MockerFixture | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | |||
| from models.workflow import WorkflowNodeExecution | |||
| from models.account import Account, Tenant | |||
| from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom | |||
| def configure_mock_execution(mock_execution): | |||
| """Configure a mock execution with proper JSON serializable values.""" | |||
| # Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values | |||
| type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}') | |||
| type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}') | |||
| type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}') | |||
| type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}') | |||
| # Configure status and triggered_from to be valid enum values | |||
| mock_execution.status = "running" | |||
| mock_execution.triggered_from = "workflow-run" | |||
| return mock_execution | |||
| @pytest.fixture | |||
| @@ -28,13 +49,30 @@ def session(): | |||
| @pytest.fixture | |||
| def repository(session): | |||
| def mock_user(): | |||
| """Create a user instance for testing.""" | |||
| user = Account() | |||
| user.id = "test-user-id" | |||
| tenant = Tenant() | |||
| tenant.id = "test-tenant" | |||
| tenant.name = "Test Workspace" | |||
| user._current_tenant = MagicMock() | |||
| user._current_tenant.id = "test-tenant" | |||
| return user | |||
| @pytest.fixture | |||
| def repository(session, mock_user): | |||
| """Create a repository instance with test data.""" | |||
| _, session_factory = session | |||
| tenant_id = "test-tenant" | |||
| app_id = "test-app" | |||
| return SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, tenant_id=tenant_id, app_id=app_id | |||
| session_factory=session_factory, | |||
| user=mock_user, | |||
| app_id=app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| @@ -45,16 +83,23 @@ def test_save(repository, session): | |||
| execution = MagicMock(spec=WorkflowNodeExecution) | |||
| execution.tenant_id = None | |||
| execution.app_id = None | |||
| execution.inputs = None | |||
| execution.process_data = None | |||
| execution.outputs = None | |||
| execution.metadata = None | |||
| # Mock the _to_db_model method to return the execution itself | |||
| # This simulates the behavior of setting tenant_id and app_id | |||
| repository._to_db_model = MagicMock(return_value=execution) | |||
| # Call save method | |||
| repository.save(execution) | |||
| # Assert tenant_id and app_id are set | |||
| assert execution.tenant_id == repository._tenant_id | |||
| assert execution.app_id == repository._app_id | |||
| # Assert _to_db_model was called with the execution | |||
| repository._to_db_model.assert_called_once_with(execution) | |||
| # Assert session.add was called | |||
| session_obj.add.assert_called_once_with(execution) | |||
| # Assert session.merge was called (now using merge for both save and update) | |||
| session_obj.merge.assert_called_once_with(execution) | |||
| def test_save_with_existing_tenant_id(repository, session): | |||
| @@ -64,16 +109,27 @@ def test_save_with_existing_tenant_id(repository, session): | |||
| execution = MagicMock(spec=WorkflowNodeExecution) | |||
| execution.tenant_id = "existing-tenant" | |||
| execution.app_id = None | |||
| execution.inputs = None | |||
| execution.process_data = None | |||
| execution.outputs = None | |||
| execution.metadata = None | |||
| # Create a modified execution that will be returned by _to_db_model | |||
| modified_execution = MagicMock(spec=WorkflowNodeExecution) | |||
| modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change | |||
| modified_execution.app_id = repository._app_id # App ID should be set | |||
| # Mock the _to_db_model method to return the modified execution | |||
| repository._to_db_model = MagicMock(return_value=modified_execution) | |||
| # Call save method | |||
| repository.save(execution) | |||
| # Assert tenant_id is not changed and app_id is set | |||
| assert execution.tenant_id == "existing-tenant" | |||
| assert execution.app_id == repository._app_id | |||
| # Assert _to_db_model was called with the execution | |||
| repository._to_db_model.assert_called_once_with(execution) | |||
| # Assert session.add was called | |||
| session_obj.add.assert_called_once_with(execution) | |||
| # Assert session.merge was called with the modified execution (now using merge for both save and update) | |||
| session_obj.merge.assert_called_once_with(modified_execution) | |||
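Taken together, these two tests pin down the new save() semantics: the domain model is first converted with _to_db_model, and the result is passed to session.merge, which inserts or updates by primary key. The assumed shape of the method, reconstructed from the tests rather than quoted:

def save(self, execution: NodeExecution) -> None:
    db_model = self._to_db_model(execution)
    with self._session_factory() as session:
        # merge() upserts on the primary key, replacing the old save/update split.
        session.merge(db_model)
        session.commit()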
| def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): | |||
| @@ -84,7 +140,16 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): | |||
| mock_stmt = mocker.MagicMock() | |||
| mock_select.return_value = mock_stmt | |||
| mock_stmt.where.return_value = mock_stmt | |||
| # Create a properly configured mock execution | |||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||
| configure_mock_execution(mock_execution) | |||
| session_obj.scalar.return_value = mock_execution | |||
| # Create a mock domain model to be returned by _to_domain_model | |||
| mock_domain_model = mocker.MagicMock() | |||
| # Mock the _to_domain_model method to return our mock domain model | |||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||
| # Call method | |||
| result = repository.get_by_node_execution_id("test-node-execution-id") | |||
| @@ -92,7 +157,10 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): | |||
| # Assert select was called with correct parameters | |||
| mock_select.assert_called_once() | |||
| session_obj.scalar.assert_called_once_with(mock_stmt) | |||
| # Assert _to_domain_model was called with the mock execution | |||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||
| # Assert the result is our mock domain model | |||
| assert result is mock_domain_model | |||
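| # The mocked select/where/scalar chain corresponds to a query of roughly this | |||
| # shape (a sketch; the real method may apply additional filters): | |||
| def get_by_node_execution_id(self, node_execution_id): | |||
|     stmt = select(WorkflowNodeExecution).where( | |||
|         WorkflowNodeExecution.node_execution_id == node_execution_id, | |||
|         WorkflowNodeExecution.tenant_id == self._tenant_id, | |||
|     ) | |||
|     with self._session_factory() as session: | |||
|         db_model = session.scalar(stmt) | |||
|         return self._to_domain_model(db_model) if db_model else None | |||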
| def test_get_by_workflow_run(repository, session, mocker: MockerFixture): | |||
| @@ -104,7 +172,16 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture): | |||
| mock_select.return_value = mock_stmt | |||
| mock_stmt.where.return_value = mock_stmt | |||
| mock_stmt.order_by.return_value = mock_stmt | |||
| # Create a properly configured mock execution | |||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||
| configure_mock_execution(mock_execution) | |||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||
| # Create a mock domain model to be returned by _to_domain_model | |||
| mock_domain_model = mocker.MagicMock() | |||
| # Mock the _to_domain_model method to return our mock domain model | |||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||
| # Call method | |||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | |||
| @@ -113,7 +190,45 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture): | |||
| # Assert select was called with correct parameters | |||
| mock_select.assert_called_once() | |||
| session_obj.scalars.assert_called_once_with(mock_stmt) | |||
| # Assert _to_domain_model was called with the mock execution | |||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||
| # Assert the result contains our mock domain model | |||
| assert len(result) == 1 | |||
| assert result[0] is mock_domain_model | |||
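| # OrderConfig(order_by=["index"], order_direction="desc") is assumed to | |||
| # translate into SQLAlchemy order_by clauses along these lines (a sketch): | |||
| def _apply_order_config(stmt, order_config): | |||
|     for field_name in order_config.order_by: | |||
|         column = getattr(WorkflowNodeExecution, field_name) | |||
|         # Resolve direction per config; default to ascending | |||
|         stmt = stmt.order_by(column.desc() if order_config.order_direction == "desc" else column.asc()) | |||
|     return stmt | |||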
| def test_get_db_models_by_workflow_run(repository, session, mocker: MockerFixture): | |||
| """Test get_db_models_by_workflow_run method.""" | |||
| session_obj, _ = session | |||
| # Set up mock | |||
| mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") | |||
| mock_stmt = mocker.MagicMock() | |||
| mock_select.return_value = mock_stmt | |||
| mock_stmt.where.return_value = mock_stmt | |||
| mock_stmt.order_by.return_value = mock_stmt | |||
| # Create a properly configured mock execution | |||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||
| configure_mock_execution(mock_execution) | |||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||
| # Mock the _to_domain_model method | |||
| to_domain_model_mock = mocker.patch.object(repository, "_to_domain_model") | |||
| # Call method | |||
| order_config = OrderConfig(order_by=["index"], order_direction="desc") | |||
| result = repository.get_db_models_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) | |||
| # Assert select was called with correct parameters | |||
| mock_select.assert_called_once() | |||
| session_obj.scalars.assert_called_once_with(mock_stmt) | |||
| # Assert the result contains our mock db model directly (without conversion to domain model) | |||
| assert len(result) == 1 | |||
| assert result[0] is mock_execution | |||
| # Verify that _to_domain_model was NOT called (since we're returning raw DB models) | |||
| to_domain_model_mock.assert_not_called() | |||
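| # Usage contrast between the two accessors exercised above (illustrative | |||
| # only): get_by_workflow_run returns converted domain models for | |||
| # application-layer callers, while get_db_models_by_workflow_run hands back | |||
| # the raw WorkflowNodeExecution rows, e.g.: | |||
| #   db_rows = repository.get_db_models_by_workflow_run( | |||
| #       workflow_run_id="test-workflow-run-id", order_config=order_config) | |||
| #   domain_models = repository.get_by_workflow_run( | |||
| #       workflow_run_id="test-workflow-run-id", order_config=order_config) | |||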
| def test_get_running_executions(repository, session, mocker: MockerFixture): | |||
| @@ -124,7 +239,16 @@ def test_get_running_executions(repository, session, mocker: MockerFixture): | |||
| mock_stmt = mocker.MagicMock() | |||
| mock_select.return_value = mock_stmt | |||
| mock_stmt.where.return_value = mock_stmt | |||
| # Create a properly configured mock execution | |||
| mock_execution = mocker.MagicMock(spec=WorkflowNodeExecution) | |||
| configure_mock_execution(mock_execution) | |||
| session_obj.scalars.return_value.all.return_value = [mock_execution] | |||
| # Create a mock domain model to be returned by _to_domain_model | |||
| mock_domain_model = mocker.MagicMock() | |||
| # Mock the _to_domain_model method to return our mock domain model | |||
| repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) | |||
| # Call method | |||
| result = repository.get_running_executions("test-workflow-run-id") | |||
| @@ -132,25 +256,36 @@ def test_get_running_executions(repository, session, mocker: MockerFixture): | |||
| # Assert select was called with correct parameters | |||
| mock_select.assert_called_once() | |||
| session_obj.scalars.assert_called_once_with(mock_stmt) | |||
| # Assert _to_domain_model was called with the mock execution | |||
| repository._to_domain_model.assert_called_once_with(mock_execution) | |||
| # Assert the result contains our mock domain model | |||
| assert len(result) == 1 | |||
| assert result[0] is mock_domain_model | |||
| def test_update_via_save(repository, session): | |||
| """Test updating an existing record via save method.""" | |||
| session_obj, _ = session | |||
| # Create a mock execution | |||
| execution = MagicMock(spec=WorkflowNodeExecution) | |||
| execution.tenant_id = None | |||
| execution.app_id = None | |||
| execution.inputs = None | |||
| execution.process_data = None | |||
| execution.outputs = None | |||
| execution.metadata = None | |||
| # Mock the _to_db_model method to return the execution itself | |||
| # This simulates the behavior of setting tenant_id and app_id | |||
| repository._to_db_model = MagicMock(return_value=execution) | |||
| # Call save method to update an existing record | |||
| repository.save(execution) | |||
| # Assert _to_db_model was called with the execution | |||
| repository._to_db_model.assert_called_once_with(execution) | |||
| # Assert session.merge was called (for updates) | |||
| session_obj.merge.assert_called_once_with(execution) | |||
| @@ -176,3 +311,118 @@ def test_clear(repository, session, mocker: MockerFixture): | |||
| mock_stmt.where.assert_called() | |||
| session_obj.execute.assert_called_once_with(mock_stmt) | |||
| session_obj.commit.assert_called_once() | |||
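| # The assertions in test_clear imply a bulk delete scoped to this | |||
| # repository's tenant and app; a sketch consistent with them (not the real | |||
| # code, and the sqlalchemy delete() import is assumed): | |||
| def clear(self): | |||
|     stmt = delete(WorkflowNodeExecution).where( | |||
|         WorkflowNodeExecution.tenant_id == self._tenant_id, | |||
|         WorkflowNodeExecution.app_id == self._app_id, | |||
|     ) | |||
|     with self._session_factory() as session: | |||
|         session.execute(stmt) | |||
|         session.commit() | |||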
| def test_to_db_model(repository): | |||
| """Test _to_db_model method.""" | |||
| # Create a domain model | |||
| domain_model = NodeExecution( | |||
| id="test-id", | |||
| workflow_id="test-workflow-id", | |||
| node_execution_id="test-node-execution-id", | |||
| workflow_run_id="test-workflow-run-id", | |||
| index=1, | |||
| predecessor_node_id="test-predecessor-id", | |||
| node_id="test-node-id", | |||
| node_type=NodeType.START, | |||
| title="Test Node", | |||
| inputs={"input_key": "input_value"}, | |||
| process_data={"process_key": "process_value"}, | |||
| outputs={"output_key": "output_value"}, | |||
| status=NodeExecutionStatus.RUNNING, | |||
| error=None, | |||
| elapsed_time=1.5, | |||
| metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100}, | |||
| created_at=datetime.now(), | |||
| finished_at=None, | |||
| ) | |||
| # Convert to DB model | |||
| db_model = repository._to_db_model(domain_model) | |||
| # Assert DB model has correct values | |||
| assert isinstance(db_model, WorkflowNodeExecution) | |||
| assert db_model.id == domain_model.id | |||
| assert db_model.tenant_id == repository._tenant_id | |||
| assert db_model.app_id == repository._app_id | |||
| assert db_model.workflow_id == domain_model.workflow_id | |||
| assert db_model.triggered_from == repository._triggered_from | |||
| assert db_model.workflow_run_id == domain_model.workflow_run_id | |||
| assert db_model.index == domain_model.index | |||
| assert db_model.predecessor_node_id == domain_model.predecessor_node_id | |||
| assert db_model.node_execution_id == domain_model.node_execution_id | |||
| assert db_model.node_id == domain_model.node_id | |||
| assert db_model.node_type == domain_model.node_type | |||
| assert db_model.title == domain_model.title | |||
| assert db_model.inputs_dict == domain_model.inputs | |||
| assert db_model.process_data_dict == domain_model.process_data | |||
| assert db_model.outputs_dict == domain_model.outputs | |||
| assert db_model.execution_metadata_dict == domain_model.metadata | |||
| assert db_model.status == domain_model.status | |||
| assert db_model.error == domain_model.error | |||
| assert db_model.elapsed_time == domain_model.elapsed_time | |||
| assert db_model.created_at == domain_model.created_at | |||
| assert db_model.created_by_role == repository._creator_user_role | |||
| assert db_model.created_by == repository._creator_user_id | |||
| assert db_model.finished_at == domain_model.finished_at | |||
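| # The *_dict accessors compared above are assumed to be thin JSON views over | |||
| # the underlying text columns, roughly (process_data_dict, outputs_dict, and | |||
| # execution_metadata_dict analogous): | |||
| @property | |||
| def inputs_dict(self): | |||
|     return json.loads(self.inputs) if self.inputs else None | |||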
| def test_to_domain_model(repository): | |||
| """Test _to_domain_model method.""" | |||
| # Create input dictionaries | |||
| inputs_dict = {"input_key": "input_value"} | |||
| process_data_dict = {"process_key": "process_value"} | |||
| outputs_dict = {"output_key": "output_value"} | |||
| metadata_dict = {str(NodeRunMetadataKey.TOTAL_TOKENS): 100} | |||
| # Create a DB model instance and populate it field by field | |||
| db_model = WorkflowNodeExecution() | |||
| db_model.id = "test-id" | |||
| db_model.tenant_id = "test-tenant-id" | |||
| db_model.app_id = "test-app-id" | |||
| db_model.workflow_id = "test-workflow-id" | |||
| db_model.triggered_from = "workflow-run" | |||
| db_model.workflow_run_id = "test-workflow-run-id" | |||
| db_model.index = 1 | |||
| db_model.predecessor_node_id = "test-predecessor-id" | |||
| db_model.node_execution_id = "test-node-execution-id" | |||
| db_model.node_id = "test-node-id" | |||
| db_model.node_type = NodeType.START.value | |||
| db_model.title = "Test Node" | |||
| db_model.inputs = json.dumps(inputs_dict) | |||
| db_model.process_data = json.dumps(process_data_dict) | |||
| db_model.outputs = json.dumps(outputs_dict) | |||
| db_model.status = WorkflowNodeExecutionStatus.RUNNING | |||
| db_model.error = None | |||
| db_model.elapsed_time = 1.5 | |||
| db_model.execution_metadata = json.dumps(metadata_dict) | |||
| db_model.created_at = datetime.now() | |||
| db_model.created_by_role = "account" | |||
| db_model.created_by = "test-user-id" | |||
| db_model.finished_at = None | |||
| # Convert to domain model | |||
| domain_model = repository._to_domain_model(db_model) | |||
| # Assert domain model has correct values | |||
| assert isinstance(domain_model, NodeExecution) | |||
| assert domain_model.id == db_model.id | |||
| assert domain_model.workflow_id == db_model.workflow_id | |||
| assert domain_model.workflow_run_id == db_model.workflow_run_id | |||
| assert domain_model.index == db_model.index | |||
| assert domain_model.predecessor_node_id == db_model.predecessor_node_id | |||
| assert domain_model.node_execution_id == db_model.node_execution_id | |||
| assert domain_model.node_id == db_model.node_id | |||
| assert domain_model.node_type == NodeType(db_model.node_type) | |||
| assert domain_model.title == db_model.title | |||
| assert domain_model.inputs == inputs_dict | |||
| assert domain_model.process_data == process_data_dict | |||
| assert domain_model.outputs == outputs_dict | |||
| assert domain_model.status == NodeExecutionStatus(db_model.status) | |||
| assert domain_model.error == db_model.error | |||
| assert domain_model.elapsed_time == db_model.elapsed_time | |||
| assert domain_model.metadata == metadata_dict | |||
| assert domain_model.created_at == db_model.created_at | |||
| assert domain_model.finished_at == db_model.finished_at | |||
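| # A natural extension of the two conversion tests above: a round-trip check, | |||
| # sketched under the assumption that _to_db_model and _to_domain_model are | |||
| # inverses for these fields (this test is not part of the diff): | |||
| def test_conversion_round_trip(repository): | |||
|     # Same construction as in test_to_db_model above | |||
|     domain_model = NodeExecution( | |||
|         id="test-id", | |||
|         workflow_id="test-workflow-id", | |||
|         node_execution_id="test-node-execution-id", | |||
|         workflow_run_id="test-workflow-run-id", | |||
|         index=1, | |||
|         predecessor_node_id="test-predecessor-id", | |||
|         node_id="test-node-id", | |||
|         node_type=NodeType.START, | |||
|         title="Test Node", | |||
|         inputs={"input_key": "input_value"}, | |||
|         process_data={"process_key": "process_value"}, | |||
|         outputs={"output_key": "output_value"}, | |||
|         status=NodeExecutionStatus.RUNNING, | |||
|         error=None, | |||
|         elapsed_time=1.5, | |||
|         metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100}, | |||
|         created_at=datetime.now(), | |||
|         finished_at=None, | |||
|     ) | |||
|     restored = repository._to_domain_model(repository._to_db_model(domain_model)) | |||
|     assert restored.id == domain_model.id | |||
|     assert restored.inputs == domain_model.inputs | |||
|     assert restored.outputs == domain_model.outputs | |||
|     assert restored.status == domain_model.status | |||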