Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.0.1
| @@ -1,8 +1,10 @@ | |||
| import json | |||
| import logging | |||
| from typing import cast | |||
| from flask import abort, request | |||
| from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||
| import services | |||
| @@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from factories import variable_factory | |||
| from fields.workflow_fields import workflow_fields, workflow_pagination_fields | |||
| from fields.workflow_run_fields import workflow_run_node_execution_fields | |||
| @@ -24,7 +27,7 @@ from models.account import Account | |||
| from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.workflow_service import WorkflowService | |||
| from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -439,10 +442,38 @@ class PublishedWorkflowApi(Resource): | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("marked_name", type=str, required=False, default="", location="json") | |||
| parser.add_argument("marked_comment", type=str, required=False, default="", location="json") | |||
| args = parser.parse_args() | |||
| # Validate name and comment length | |||
| if args.marked_name and len(args.marked_name) > 20: | |||
| raise ValueError("Marked name cannot exceed 20 characters") | |||
| if args.marked_comment and len(args.marked_comment) > 100: | |||
| raise ValueError("Marked comment cannot exceed 100 characters") | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | |||
| with Session(db.engine) as session: | |||
| workflow = workflow_service.publish_workflow( | |||
| session=session, | |||
| app_model=app_model, | |||
| account=current_user, | |||
| marked_name=args.marked_name or "", | |||
| marked_comment=args.marked_comment or "", | |||
| ) | |||
| app_model.workflow_id = workflow.id | |||
| db.session.commit() | |||
| workflow_created_at = TimestampField().format(workflow.created_at) | |||
| return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} | |||
| session.commit() | |||
| return { | |||
| "result": "success", | |||
| "created_at": workflow_created_at, | |||
| } | |||
| class DefaultBlockConfigsApi(Resource): | |||
| @@ -564,37 +595,193 @@ class PublishedAllWorkflowApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument("user_id", type=str, required=False, location="args") | |||
| parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") | |||
| args = parser.parse_args() | |||
| page = int(args.get("page", 1)) | |||
| limit = int(args.get("limit", 10)) | |||
| user_id = args.get("user_id") | |||
| named_only = args.get("named_only", False) | |||
| if user_id: | |||
| if user_id != current_user.id: | |||
| raise Forbidden() | |||
| user_id = cast(str, user_id) | |||
| workflow_service = WorkflowService() | |||
| with Session(db.engine) as session: | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=session, | |||
| app_model=app_model, | |||
| page=page, | |||
| limit=limit, | |||
| user_id=user_id, | |||
| named_only=named_only, | |||
| ) | |||
| return { | |||
| "items": workflows, | |||
| "page": page, | |||
| "limit": limit, | |||
| "has_more": has_more, | |||
| } | |||
| class WorkflowByIdApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||
| @marshal_with(workflow_fields) | |||
| def patch(self, app_model: App, workflow_id: str): | |||
| """ | |||
| Update workflow attributes | |||
| """ | |||
| # Check permission | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("marked_name", type=str, required=False, location="json") | |||
| parser.add_argument("marked_comment", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| page = args.get("page") | |||
| limit = args.get("limit") | |||
| # Validate name and comment length | |||
| if args.marked_name and len(args.marked_name) > 20: | |||
| raise ValueError("Marked name cannot exceed 20 characters") | |||
| if args.marked_comment and len(args.marked_comment) > 100: | |||
| raise ValueError("Marked comment cannot exceed 100 characters") | |||
| args = parser.parse_args() | |||
| # Prepare update data | |||
| update_data = {} | |||
| if args.get("marked_name") is not None: | |||
| update_data["marked_name"] = args["marked_name"] | |||
| if args.get("marked_comment") is not None: | |||
| update_data["marked_comment"] = args["marked_comment"] | |||
| if not update_data: | |||
| return {"message": "No valid fields to update"}, 400 | |||
| workflow_service = WorkflowService() | |||
| workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit) | |||
| return {"items": workflows, "page": page, "limit": limit, "has_more": has_more} | |||
| # Create a session and manage the transaction | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow = workflow_service.update_workflow( | |||
| session=session, | |||
| workflow_id=workflow_id, | |||
| tenant_id=app_model.tenant_id, | |||
| account_id=current_user.id, | |||
| data=update_data, | |||
| ) | |||
| if not workflow: | |||
| raise NotFound("Workflow not found") | |||
| # Commit the transaction in the controller | |||
| session.commit() | |||
| return workflow | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||
| def delete(self, app_model: App, workflow_id: str): | |||
| """ | |||
| Delete workflow | |||
| """ | |||
| # Check permission | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| workflow_service = WorkflowService() | |||
| # Create a session and manage the transaction | |||
| with Session(db.engine) as session: | |||
| try: | |||
| workflow_service.delete_workflow( | |||
| session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id | |||
| ) | |||
| # Commit the transaction in the controller | |||
| session.commit() | |||
| except WorkflowInUseError as e: | |||
| abort(400, description=str(e)) | |||
| except DraftWorkflowDeletionError as e: | |||
| abort(400, description=str(e)) | |||
| except ValueError as e: | |||
| raise NotFound(str(e)) | |||
| return None, 204 | |||
| api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") | |||
| api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config") | |||
| api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") | |||
| api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") | |||
| api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") | |||
| api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") | |||
| api.add_resource( | |||
| DraftWorkflowApi, | |||
| "/apps/<uuid:app_id>/workflows/draft", | |||
| ) | |||
| api.add_resource( | |||
| WorkflowConfigApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/config", | |||
| ) | |||
| api.add_resource( | |||
| AdvancedChatDraftWorkflowRunApi, | |||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run", | |||
| ) | |||
| api.add_resource( | |||
| DraftWorkflowRunApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/run", | |||
| ) | |||
| api.add_resource( | |||
| WorkflowTaskStopApi, | |||
| "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop", | |||
| ) | |||
| api.add_resource( | |||
| DraftWorkflowNodeRunApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| AdvancedChatDraftRunIterationNodeApi, | |||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" | |||
| WorkflowDraftRunIterationNodeApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| AdvancedChatDraftRunLoopNodeApi, | |||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run") | |||
| api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish") | |||
| api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows") | |||
| api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs") | |||
| api.add_resource( | |||
| DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>" | |||
| WorkflowDraftRunLoopNodeApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| PublishedWorkflowApi, | |||
| "/apps/<uuid:app_id>/workflows/publish", | |||
| ) | |||
| api.add_resource( | |||
| PublishedAllWorkflowApi, | |||
| "/apps/<uuid:app_id>/workflows", | |||
| ) | |||
| api.add_resource( | |||
| DefaultBlockConfigsApi, | |||
| "/apps/<uuid:app_id>/workflows/default-workflow-block-configs", | |||
| ) | |||
| api.add_resource( | |||
| DefaultBlockConfigApi, | |||
| "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>", | |||
| ) | |||
| api.add_resource( | |||
| ConvertToWorkflowApi, | |||
| "/apps/<uuid:app_id>/convert-to-workflow", | |||
| ) | |||
| api.add_resource( | |||
| WorkflowByIdApi, | |||
| "/apps/<uuid:app_id>/workflows/<string:workflow_id>", | |||
| ) | |||
| api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow") | |||
| @@ -45,7 +45,9 @@ workflow_fields = { | |||
| "graph": fields.Raw(attribute="graph_dict"), | |||
| "features": fields.Raw(attribute="features_dict"), | |||
| "hash": fields.String(attribute="unique_hash"), | |||
| "version": fields.String(attribute="version"), | |||
| "version": fields.String, | |||
| "marked_name": fields.String, | |||
| "marked_comment": fields.String, | |||
| "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), | |||
| "created_at": TimestampField, | |||
| "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), | |||
| @@ -0,0 +1,29 @@ | |||
| """add marked_name and marked_comment in workflows | |||
| Revision ID: ee79d9b1c156 | |||
| Revises: 4413929e1ec2 | |||
| Create Date: 2025-03-03 14:36:05.750346 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'ee79d9b1c156' | |||
| down_revision = '5511c782ee4c' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| with op.batch_alter_table('workflows', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default='')) | |||
| batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default='')) | |||
| def downgrade(): | |||
| with op.batch_alter_table('workflows', schema=None) as batch_op: | |||
| batch_op.drop_column('marked_comment') | |||
| batch_op.drop_column('marked_name') | |||
| @@ -2,7 +2,8 @@ import json | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from enum import Enum | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | |||
| from uuid import uuid4 | |||
| if TYPE_CHECKING: | |||
| from models.model import AppMode | |||
| @@ -108,7 +109,9 @@ class Workflow(Base): | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| type: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| version: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| version: Mapped[str] | |||
| marked_name: Mapped[str] = mapped_column(default="", server_default="") | |||
| marked_comment: Mapped[str] = mapped_column(default="", server_default="") | |||
| graph: Mapped[str] = mapped_column(sa.Text) | |||
| _features: Mapped[str] = mapped_column("features", sa.TEXT) | |||
| created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| @@ -127,8 +130,9 @@ class Workflow(Base): | |||
| "conversation_variables", db.Text, nullable=False, server_default="{}" | |||
| ) | |||
| def __init__( | |||
| self, | |||
| @classmethod | |||
| def new( | |||
| cls, | |||
| *, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| @@ -139,16 +143,25 @@ class Workflow(Base): | |||
| created_by: str, | |||
| environment_variables: Sequence[Variable], | |||
| conversation_variables: Sequence[Variable], | |||
| ): | |||
| self.tenant_id = tenant_id | |||
| self.app_id = app_id | |||
| self.type = type | |||
| self.version = version | |||
| self.graph = graph | |||
| self.features = features | |||
| self.created_by = created_by | |||
| self.environment_variables = environment_variables or [] | |||
| self.conversation_variables = conversation_variables or [] | |||
| marked_name: str = "", | |||
| marked_comment: str = "", | |||
| ) -> Self: | |||
| workflow = Workflow() | |||
| workflow.id = str(uuid4()) | |||
| workflow.tenant_id = tenant_id | |||
| workflow.app_id = app_id | |||
| workflow.type = type | |||
| workflow.version = version | |||
| workflow.graph = graph | |||
| workflow.features = features | |||
| workflow.created_by = created_by | |||
| workflow.environment_variables = environment_variables or [] | |||
| workflow.conversation_variables = conversation_variables or [] | |||
| workflow.marked_name = marked_name | |||
| workflow.marked_comment = marked_comment | |||
| workflow.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow.updated_at = workflow.created_at | |||
| return workflow | |||
| @property | |||
| def created_by_account(self): | |||
| @@ -151,6 +151,7 @@ pytest-benchmark = "~4.0.0" | |||
| pytest-env = "~1.1.3" | |||
| pytest-mock = "~3.14.0" | |||
| types-beautifulsoup4 = "~4.12.0.20241020" | |||
| types-deprecated = "~1.2.15.20250304" | |||
| types-flask-cors = "~5.0.0.20240902" | |||
| types-flask-migrate = "~4.1.0.20250112" | |||
| types-html5lib = "~1.1.11.20241018" | |||
| @@ -0,0 +1,10 @@ | |||
| class WorkflowInUseError(ValueError): | |||
| """Raised when attempting to delete a workflow that's in use by an app""" | |||
| pass | |||
| class DraftWorkflowDeletionError(ValueError): | |||
| """Raised when attempting to delete a draft workflow""" | |||
| pass | |||
| @@ -5,7 +5,8 @@ from datetime import UTC, datetime | |||
| from typing import Any, Optional | |||
| from uuid import uuid4 | |||
| from sqlalchemy import desc | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| @@ -36,6 +37,8 @@ from models.workflow import ( | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.workflow.workflow_converter import WorkflowConverter | |||
| from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | |||
| class WorkflowService: | |||
| """ | |||
| @@ -79,22 +82,38 @@ class WorkflowService: | |||
| return workflow | |||
| def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]: | |||
| def get_all_published_workflow( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| app_model: App, | |||
| page: int, | |||
| limit: int, | |||
| user_id: str | None, | |||
| named_only: bool = False, | |||
| ) -> tuple[Sequence[Workflow], bool]: | |||
| """ | |||
| Get published workflow with pagination | |||
| """ | |||
| if not app_model.workflow_id: | |||
| return [], False | |||
| workflows = ( | |||
| db.session.query(Workflow) | |||
| .filter(Workflow.app_id == app_model.id) | |||
| .order_by(desc(Workflow.version)) | |||
| .offset((page - 1) * limit) | |||
| stmt = ( | |||
| select(Workflow) | |||
| .where(Workflow.app_id == app_model.id) | |||
| .order_by(Workflow.version.desc()) | |||
| .limit(limit + 1) | |||
| .all() | |||
| .offset((page - 1) * limit) | |||
| ) | |||
| if user_id: | |||
| stmt = stmt.where(Workflow.created_by == user_id) | |||
| if named_only: | |||
| stmt = stmt.where(Workflow.marked_name != "") | |||
| workflows = session.scalars(stmt).all() | |||
| has_more = len(workflows) > limit | |||
| if has_more: | |||
| workflows = workflows[:-1] | |||
| @@ -157,23 +176,26 @@ class WorkflowService: | |||
| # return draft workflow | |||
| return workflow | |||
| def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: | |||
| """ | |||
| Publish workflow from draft | |||
| :param app_model: App instance | |||
| :param account: Account instance | |||
| :param draft_workflow: Workflow instance | |||
| """ | |||
| if not draft_workflow: | |||
| # fetch draft workflow by app_model | |||
| draft_workflow = self.get_draft_workflow(app_model=app_model) | |||
| def publish_workflow( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| app_model: App, | |||
| account: Account, | |||
| marked_name: str = "", | |||
| marked_comment: str = "", | |||
| ) -> Workflow: | |||
| draft_workflow_stmt = select(Workflow).where( | |||
| Workflow.tenant_id == app_model.tenant_id, | |||
| Workflow.app_id == app_model.id, | |||
| Workflow.version == "draft", | |||
| ) | |||
| draft_workflow = session.scalar(draft_workflow_stmt) | |||
| if not draft_workflow: | |||
| raise ValueError("No valid workflow found.") | |||
| # create new workflow | |||
| workflow = Workflow( | |||
| workflow = Workflow.new( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type=draft_workflow.type, | |||
| @@ -183,15 +205,12 @@ class WorkflowService: | |||
| created_by=account.id, | |||
| environment_variables=draft_workflow.environment_variables, | |||
| conversation_variables=draft_workflow.conversation_variables, | |||
| marked_name=marked_name, | |||
| marked_comment=marked_comment, | |||
| ) | |||
| # commit db session changes | |||
| db.session.add(workflow) | |||
| db.session.flush() | |||
| db.session.commit() | |||
| app_model.workflow_id = workflow.id | |||
| db.session.commit() | |||
| session.add(workflow) | |||
| # trigger app workflow events | |||
| app_published_workflow_was_updated.send(app_model, published_workflow=workflow) | |||
| @@ -436,3 +455,65 @@ class WorkflowService: | |||
| ) | |||
| else: | |||
| raise ValueError(f"Invalid app mode: {app_model.mode}") | |||
| def update_workflow( | |||
| self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict | |||
| ) -> Optional[Workflow]: | |||
| """ | |||
| Update workflow attributes | |||
| :param session: SQLAlchemy database session | |||
| :param workflow_id: Workflow ID | |||
| :param tenant_id: Tenant ID | |||
| :param account_id: Account ID (for permission check) | |||
| :param data: Dictionary containing fields to update | |||
| :return: Updated workflow or None if not found | |||
| """ | |||
| stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) | |||
| workflow = session.scalar(stmt) | |||
| if not workflow: | |||
| return None | |||
| allowed_fields = ["marked_name", "marked_comment"] | |||
| for field, value in data.items(): | |||
| if field in allowed_fields: | |||
| setattr(workflow, field, value) | |||
| workflow.updated_by = account_id | |||
| workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| return workflow | |||
| def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: | |||
| """ | |||
| Delete a workflow | |||
| :param session: SQLAlchemy database session | |||
| :param workflow_id: Workflow ID | |||
| :param tenant_id: Tenant ID | |||
| :return: True if successful | |||
| :raises: ValueError if workflow not found | |||
| :raises: WorkflowInUseError if workflow is in use | |||
| :raises: DraftWorkflowDeletionError if workflow is a draft version | |||
| """ | |||
| stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) | |||
| workflow = session.scalar(stmt) | |||
| if not workflow: | |||
| raise ValueError(f"Workflow with ID {workflow_id} not found") | |||
| # Check if workflow is a draft version | |||
| if workflow.version == "draft": | |||
| raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") | |||
| # Check if this workflow is currently referenced by an app | |||
| stmt = select(App).where(App.workflow_id == workflow_id) | |||
| app = session.scalar(stmt) | |||
| if app: | |||
| # Cannot delete a workflow that's currently in use by an app | |||
| raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") | |||
| session.delete(workflow) | |||
| return True | |||
| @@ -0,0 +1,162 @@ | |||
| from unittest.mock import MagicMock | |||
| import pytest | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| from services.workflow_service import WorkflowService | |||
| class TestWorkflowService: | |||
| @pytest.fixture | |||
| def workflow_service(self): | |||
| return WorkflowService() | |||
| @pytest.fixture | |||
| def mock_app(self): | |||
| app = MagicMock(spec=App) | |||
| app.id = "app-id-1" | |||
| app.workflow_id = "workflow-id-1" | |||
| app.tenant_id = "tenant-id-1" | |||
| return app | |||
| @pytest.fixture | |||
| def mock_workflows(self): | |||
| workflows = [] | |||
| for i in range(5): | |||
| workflow = MagicMock(spec=Workflow) | |||
| workflow.id = f"workflow-id-{i}" | |||
| workflow.app_id = "app-id-1" | |||
| workflow.created_at = f"2023-01-0{5 - i}" # Descending date order | |||
| workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2" | |||
| workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else "" | |||
| workflows.append(workflow) | |||
| return workflows | |||
| def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): | |||
| mock_app.workflow_id = None | |||
| mock_session = MagicMock() | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None | |||
| ) | |||
| assert workflows == [] | |||
| assert has_more is False | |||
| mock_session.scalars.assert_not_called() | |||
| def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| mock_scalar_result.all.return_value = mock_workflows[:3] | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None | |||
| ) | |||
| assert workflows == mock_workflows[:3] | |||
| assert has_more is False | |||
| mock_session.scalars.assert_called_once() | |||
| def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| # Return 4 items when limit is 3, which should indicate has_more=True | |||
| mock_scalar_result.all.return_value = mock_workflows[:4] | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None | |||
| ) | |||
| # Should return only the first 3 items | |||
| assert len(workflows) == 3 | |||
| assert workflows == mock_workflows[:3] | |||
| assert has_more is True | |||
| # Test page 2 | |||
| mock_scalar_result.all.return_value = mock_workflows[3:] | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None | |||
| ) | |||
| assert len(workflows) == 2 | |||
| assert has_more is False | |||
| def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| # Filter workflows for user-id-1 | |||
| filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"] | |||
| mock_scalar_result.all.return_value = filtered_workflows | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1" | |||
| ) | |||
| assert workflows == filtered_workflows | |||
| assert has_more is False | |||
| mock_session.scalars.assert_called_once() | |||
| # Verify that the select contains a user filter clause | |||
| args = mock_session.scalars.call_args[0][0] | |||
| assert "created_by" in str(args) | |||
| def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| # Filter workflows that have a marked_name | |||
| named_workflows = [w for w in mock_workflows if w.marked_name] | |||
| mock_scalar_result.all.return_value = named_workflows | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True | |||
| ) | |||
| assert workflows == named_workflows | |||
| assert has_more is False | |||
| mock_session.scalars.assert_called_once() | |||
| # Verify that the select contains a named_only filter clause | |||
| args = mock_session.scalars.call_args[0][0] | |||
| assert "marked_name !=" in str(args) | |||
| def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| # Combined filter: user-id-1 and has marked_name | |||
| filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name] | |||
| mock_scalar_result.all.return_value = filtered_workflows | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True | |||
| ) | |||
| assert workflows == filtered_workflows | |||
| assert has_more is False | |||
| mock_session.scalars.assert_called_once() | |||
| # Verify that both filters are applied | |||
| args = mock_session.scalars.call_args[0][0] | |||
| assert "created_by" in str(args) | |||
| assert "marked_name !=" in str(args) | |||
| def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app): | |||
| mock_session = MagicMock() | |||
| mock_scalar_result = MagicMock() | |||
| mock_scalar_result.all.return_value = [] | |||
| mock_session.scalars.return_value = mock_scalar_result | |||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None | |||
| ) | |||
| assert workflows == [] | |||
| assert has_more is False | |||
| mock_session.scalars.assert_called_once() | |||