Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.0.1
| import json | import json | ||||
| import logging | import logging | ||||
| from typing import cast | |||||
| from flask import abort, request | from flask import abort, request | ||||
| from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore | from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore | ||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | ||||
| import services | import services | ||||
| from controllers.console.wraps import account_initialization_required, setup_required | from controllers.console.wraps import account_initialization_required, setup_required | ||||
| from core.app.apps.base_app_queue_manager import AppQueueManager | from core.app.apps.base_app_queue_manager import AppQueueManager | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from extensions.ext_database import db | |||||
| from factories import variable_factory | from factories import variable_factory | ||||
| from fields.workflow_fields import workflow_fields, workflow_pagination_fields | from fields.workflow_fields import workflow_fields, workflow_pagination_fields | ||||
| from fields.workflow_run_fields import workflow_run_node_execution_fields | from fields.workflow_run_fields import workflow_run_node_execution_fields | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| from services.app_generate_service import AppGenerateService | from services.app_generate_service import AppGenerateService | ||||
| from services.errors.app import WorkflowHashNotEqualError | from services.errors.app import WorkflowHashNotEqualError | ||||
| from services.workflow_service import WorkflowService | |||||
| from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| if not isinstance(current_user, Account): | if not isinstance(current_user, Account): | ||||
| raise Forbidden() | raise Forbidden() | ||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("marked_name", type=str, required=False, default="", location="json") | |||||
| parser.add_argument("marked_comment", type=str, required=False, default="", location="json") | |||||
| args = parser.parse_args() | |||||
| # Validate name and comment length | |||||
| if args.marked_name and len(args.marked_name) > 20: | |||||
| raise ValueError("Marked name cannot exceed 20 characters") | |||||
| if args.marked_comment and len(args.marked_comment) > 100: | |||||
| raise ValueError("Marked comment cannot exceed 100 characters") | |||||
| workflow_service = WorkflowService() | workflow_service = WorkflowService() | ||||
| workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | |||||
| with Session(db.engine) as session: | |||||
| workflow = workflow_service.publish_workflow( | |||||
| session=session, | |||||
| app_model=app_model, | |||||
| account=current_user, | |||||
| marked_name=args.marked_name or "", | |||||
| marked_comment=args.marked_comment or "", | |||||
| ) | |||||
| app_model.workflow_id = workflow.id | |||||
| db.session.commit() | |||||
| workflow_created_at = TimestampField().format(workflow.created_at) | |||||
| return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} | |||||
| session.commit() | |||||
| return { | |||||
| "result": "success", | |||||
| "created_at": workflow_created_at, | |||||
| } | |||||
| class DefaultBlockConfigsApi(Resource): | class DefaultBlockConfigsApi(Resource): | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") | parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") | ||||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | ||||
| parser.add_argument("user_id", type=str, required=False, location="args") | |||||
| parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") | |||||
| args = parser.parse_args() | |||||
| page = int(args.get("page", 1)) | |||||
| limit = int(args.get("limit", 10)) | |||||
| user_id = args.get("user_id") | |||||
| named_only = args.get("named_only", False) | |||||
| if user_id: | |||||
| if user_id != current_user.id: | |||||
| raise Forbidden() | |||||
| user_id = cast(str, user_id) | |||||
| workflow_service = WorkflowService() | |||||
| with Session(db.engine) as session: | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=session, | |||||
| app_model=app_model, | |||||
| page=page, | |||||
| limit=limit, | |||||
| user_id=user_id, | |||||
| named_only=named_only, | |||||
| ) | |||||
| return { | |||||
| "items": workflows, | |||||
| "page": page, | |||||
| "limit": limit, | |||||
| "has_more": has_more, | |||||
| } | |||||
| class WorkflowByIdApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||||
| @marshal_with(workflow_fields) | |||||
| def patch(self, app_model: App, workflow_id: str): | |||||
| """ | |||||
| Update workflow attributes | |||||
| """ | |||||
| # Check permission | |||||
| if not current_user.is_editor: | |||||
| raise Forbidden() | |||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("marked_name", type=str, required=False, location="json") | |||||
| parser.add_argument("marked_comment", type=str, required=False, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| page = args.get("page") | |||||
| limit = args.get("limit") | |||||
| # Validate name and comment length | |||||
| if args.marked_name and len(args.marked_name) > 20: | |||||
| raise ValueError("Marked name cannot exceed 20 characters") | |||||
| if args.marked_comment and len(args.marked_comment) > 100: | |||||
| raise ValueError("Marked comment cannot exceed 100 characters") | |||||
| args = parser.parse_args() | |||||
| # Prepare update data | |||||
| update_data = {} | |||||
| if args.get("marked_name") is not None: | |||||
| update_data["marked_name"] = args["marked_name"] | |||||
| if args.get("marked_comment") is not None: | |||||
| update_data["marked_comment"] = args["marked_comment"] | |||||
| if not update_data: | |||||
| return {"message": "No valid fields to update"}, 400 | |||||
| workflow_service = WorkflowService() | workflow_service = WorkflowService() | ||||
| workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit) | |||||
| return {"items": workflows, "page": page, "limit": limit, "has_more": has_more} | |||||
| # Create a session and manage the transaction | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow = workflow_service.update_workflow( | |||||
| session=session, | |||||
| workflow_id=workflow_id, | |||||
| tenant_id=app_model.tenant_id, | |||||
| account_id=current_user.id, | |||||
| data=update_data, | |||||
| ) | |||||
| if not workflow: | |||||
| raise NotFound("Workflow not found") | |||||
| # Commit the transaction in the controller | |||||
| session.commit() | |||||
| return workflow | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||||
| def delete(self, app_model: App, workflow_id: str): | |||||
| """ | |||||
| Delete workflow | |||||
| """ | |||||
| # Check permission | |||||
| if not current_user.is_editor: | |||||
| raise Forbidden() | |||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| workflow_service = WorkflowService() | |||||
| # Create a session and manage the transaction | |||||
| with Session(db.engine) as session: | |||||
| try: | |||||
| workflow_service.delete_workflow( | |||||
| session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id | |||||
| ) | |||||
| # Commit the transaction in the controller | |||||
| session.commit() | |||||
| except WorkflowInUseError as e: | |||||
| abort(400, description=str(e)) | |||||
| except DraftWorkflowDeletionError as e: | |||||
| abort(400, description=str(e)) | |||||
| except ValueError as e: | |||||
| raise NotFound(str(e)) | |||||
| return None, 204 | |||||
| api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") | |||||
| api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config") | |||||
| api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") | |||||
| api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") | |||||
| api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") | |||||
| api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") | |||||
| api.add_resource( | |||||
| DraftWorkflowApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft", | |||||
| ) | |||||
| api.add_resource( | |||||
| WorkflowConfigApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft/config", | |||||
| ) | |||||
| api.add_resource( | |||||
| AdvancedChatDraftWorkflowRunApi, | |||||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run", | |||||
| ) | |||||
| api.add_resource( | |||||
| DraftWorkflowRunApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft/run", | |||||
| ) | |||||
| api.add_resource( | |||||
| WorkflowTaskStopApi, | |||||
| "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop", | |||||
| ) | |||||
| api.add_resource( | |||||
| DraftWorkflowNodeRunApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run", | |||||
| ) | |||||
| api.add_resource( | api.add_resource( | ||||
| AdvancedChatDraftRunIterationNodeApi, | AdvancedChatDraftRunIterationNodeApi, | ||||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", | "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", | ||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" | |||||
| WorkflowDraftRunIterationNodeApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run", | |||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| AdvancedChatDraftRunLoopNodeApi, | AdvancedChatDraftRunLoopNodeApi, | ||||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run", | "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run", | ||||
| ) | ) | ||||
| api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run") | |||||
| api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish") | |||||
| api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows") | |||||
| api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs") | |||||
| api.add_resource( | api.add_resource( | ||||
| DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>" | |||||
| WorkflowDraftRunLoopNodeApi, | |||||
| "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run", | |||||
| ) | |||||
| api.add_resource( | |||||
| PublishedWorkflowApi, | |||||
| "/apps/<uuid:app_id>/workflows/publish", | |||||
| ) | |||||
| api.add_resource( | |||||
| PublishedAllWorkflowApi, | |||||
| "/apps/<uuid:app_id>/workflows", | |||||
| ) | |||||
| api.add_resource( | |||||
| DefaultBlockConfigsApi, | |||||
| "/apps/<uuid:app_id>/workflows/default-workflow-block-configs", | |||||
| ) | |||||
| api.add_resource( | |||||
| DefaultBlockConfigApi, | |||||
| "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>", | |||||
| ) | |||||
| api.add_resource( | |||||
| ConvertToWorkflowApi, | |||||
| "/apps/<uuid:app_id>/convert-to-workflow", | |||||
| ) | |||||
| api.add_resource( | |||||
| WorkflowByIdApi, | |||||
| "/apps/<uuid:app_id>/workflows/<string:workflow_id>", | |||||
| ) | ) | ||||
| api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow") |
| "graph": fields.Raw(attribute="graph_dict"), | "graph": fields.Raw(attribute="graph_dict"), | ||||
| "features": fields.Raw(attribute="features_dict"), | "features": fields.Raw(attribute="features_dict"), | ||||
| "hash": fields.String(attribute="unique_hash"), | "hash": fields.String(attribute="unique_hash"), | ||||
| "version": fields.String(attribute="version"), | |||||
| "version": fields.String, | |||||
| "marked_name": fields.String, | |||||
| "marked_comment": fields.String, | |||||
| "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), | "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), | ||||
| "created_at": TimestampField, | "created_at": TimestampField, | ||||
| "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), | "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), |
| """add marked_name and marked_comment in workflows | |||||
| Revision ID: ee79d9b1c156 | |||||
| Revises: 4413929e1ec2 | |||||
| Create Date: 2025-03-03 14:36:05.750346 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'ee79d9b1c156' | |||||
| down_revision = '5511c782ee4c' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| with op.batch_alter_table('workflows', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default='')) | |||||
| batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default='')) | |||||
| def downgrade(): | |||||
| with op.batch_alter_table('workflows', schema=None) as batch_op: | |||||
| batch_op.drop_column('marked_comment') | |||||
| batch_op.drop_column('marked_name') |
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from enum import Enum | from enum import Enum | ||||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | |||||
| from uuid import uuid4 | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | ||||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | ||||
| type: Mapped[str] = mapped_column(db.String(255), nullable=False) | type: Mapped[str] = mapped_column(db.String(255), nullable=False) | ||||
| version: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| version: Mapped[str] | |||||
| marked_name: Mapped[str] = mapped_column(default="", server_default="") | |||||
| marked_comment: Mapped[str] = mapped_column(default="", server_default="") | |||||
| graph: Mapped[str] = mapped_column(sa.Text) | graph: Mapped[str] = mapped_column(sa.Text) | ||||
| _features: Mapped[str] = mapped_column("features", sa.TEXT) | _features: Mapped[str] = mapped_column("features", sa.TEXT) | ||||
| created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) | ||||
| "conversation_variables", db.Text, nullable=False, server_default="{}" | "conversation_variables", db.Text, nullable=False, server_default="{}" | ||||
| ) | ) | ||||
| def __init__( | |||||
| self, | |||||
| @classmethod | |||||
| def new( | |||||
| cls, | |||||
| *, | *, | ||||
| tenant_id: str, | tenant_id: str, | ||||
| app_id: str, | app_id: str, | ||||
| created_by: str, | created_by: str, | ||||
| environment_variables: Sequence[Variable], | environment_variables: Sequence[Variable], | ||||
| conversation_variables: Sequence[Variable], | conversation_variables: Sequence[Variable], | ||||
| ): | |||||
| self.tenant_id = tenant_id | |||||
| self.app_id = app_id | |||||
| self.type = type | |||||
| self.version = version | |||||
| self.graph = graph | |||||
| self.features = features | |||||
| self.created_by = created_by | |||||
| self.environment_variables = environment_variables or [] | |||||
| self.conversation_variables = conversation_variables or [] | |||||
| marked_name: str = "", | |||||
| marked_comment: str = "", | |||||
| ) -> Self: | |||||
| workflow = Workflow() | |||||
| workflow.id = str(uuid4()) | |||||
| workflow.tenant_id = tenant_id | |||||
| workflow.app_id = app_id | |||||
| workflow.type = type | |||||
| workflow.version = version | |||||
| workflow.graph = graph | |||||
| workflow.features = features | |||||
| workflow.created_by = created_by | |||||
| workflow.environment_variables = environment_variables or [] | |||||
| workflow.conversation_variables = conversation_variables or [] | |||||
| workflow.marked_name = marked_name | |||||
| workflow.marked_comment = marked_comment | |||||
| workflow.created_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| workflow.updated_at = workflow.created_at | |||||
| return workflow | |||||
| @property | @property | ||||
| def created_by_account(self): | def created_by_account(self): |
| pytest-env = "~1.1.3" | pytest-env = "~1.1.3" | ||||
| pytest-mock = "~3.14.0" | pytest-mock = "~3.14.0" | ||||
| types-beautifulsoup4 = "~4.12.0.20241020" | types-beautifulsoup4 = "~4.12.0.20241020" | ||||
| types-deprecated = "~1.2.15.20250304" | |||||
| types-flask-cors = "~5.0.0.20240902" | types-flask-cors = "~5.0.0.20240902" | ||||
| types-flask-migrate = "~4.1.0.20250112" | types-flask-migrate = "~4.1.0.20250112" | ||||
| types-html5lib = "~1.1.11.20241018" | types-html5lib = "~1.1.11.20241018" |
| class WorkflowInUseError(ValueError): | |||||
| """Raised when attempting to delete a workflow that's in use by an app""" | |||||
| pass | |||||
| class DraftWorkflowDeletionError(ValueError): | |||||
| """Raised when attempting to delete a draft workflow""" | |||||
| pass |
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from sqlalchemy import desc | |||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | ||||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | ||||
| from services.errors.app import WorkflowHashNotEqualError | from services.errors.app import WorkflowHashNotEqualError | ||||
| from services.workflow.workflow_converter import WorkflowConverter | from services.workflow.workflow_converter import WorkflowConverter | ||||
| from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | |||||
| class WorkflowService: | class WorkflowService: | ||||
| """ | """ | ||||
| return workflow | return workflow | ||||
| def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]: | |||||
| def get_all_published_workflow( | |||||
| self, | |||||
| *, | |||||
| session: Session, | |||||
| app_model: App, | |||||
| page: int, | |||||
| limit: int, | |||||
| user_id: str | None, | |||||
| named_only: bool = False, | |||||
| ) -> tuple[Sequence[Workflow], bool]: | |||||
| """ | """ | ||||
| Get published workflow with pagination | Get published workflow with pagination | ||||
| """ | """ | ||||
| if not app_model.workflow_id: | if not app_model.workflow_id: | ||||
| return [], False | return [], False | ||||
| workflows = ( | |||||
| db.session.query(Workflow) | |||||
| .filter(Workflow.app_id == app_model.id) | |||||
| .order_by(desc(Workflow.version)) | |||||
| .offset((page - 1) * limit) | |||||
| stmt = ( | |||||
| select(Workflow) | |||||
| .where(Workflow.app_id == app_model.id) | |||||
| .order_by(Workflow.version.desc()) | |||||
| .limit(limit + 1) | .limit(limit + 1) | ||||
| .all() | |||||
| .offset((page - 1) * limit) | |||||
| ) | ) | ||||
| if user_id: | |||||
| stmt = stmt.where(Workflow.created_by == user_id) | |||||
| if named_only: | |||||
| stmt = stmt.where(Workflow.marked_name != "") | |||||
| workflows = session.scalars(stmt).all() | |||||
| has_more = len(workflows) > limit | has_more = len(workflows) > limit | ||||
| if has_more: | if has_more: | ||||
| workflows = workflows[:-1] | workflows = workflows[:-1] | ||||
| # return draft workflow | # return draft workflow | ||||
| return workflow | return workflow | ||||
| def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: | |||||
| """ | |||||
| Publish workflow from draft | |||||
| :param app_model: App instance | |||||
| :param account: Account instance | |||||
| :param draft_workflow: Workflow instance | |||||
| """ | |||||
| if not draft_workflow: | |||||
| # fetch draft workflow by app_model | |||||
| draft_workflow = self.get_draft_workflow(app_model=app_model) | |||||
| def publish_workflow( | |||||
| self, | |||||
| *, | |||||
| session: Session, | |||||
| app_model: App, | |||||
| account: Account, | |||||
| marked_name: str = "", | |||||
| marked_comment: str = "", | |||||
| ) -> Workflow: | |||||
| draft_workflow_stmt = select(Workflow).where( | |||||
| Workflow.tenant_id == app_model.tenant_id, | |||||
| Workflow.app_id == app_model.id, | |||||
| Workflow.version == "draft", | |||||
| ) | |||||
| draft_workflow = session.scalar(draft_workflow_stmt) | |||||
| if not draft_workflow: | if not draft_workflow: | ||||
| raise ValueError("No valid workflow found.") | raise ValueError("No valid workflow found.") | ||||
| # create new workflow | # create new workflow | ||||
| workflow = Workflow( | |||||
| workflow = Workflow.new( | |||||
| tenant_id=app_model.tenant_id, | tenant_id=app_model.tenant_id, | ||||
| app_id=app_model.id, | app_id=app_model.id, | ||||
| type=draft_workflow.type, | type=draft_workflow.type, | ||||
| created_by=account.id, | created_by=account.id, | ||||
| environment_variables=draft_workflow.environment_variables, | environment_variables=draft_workflow.environment_variables, | ||||
| conversation_variables=draft_workflow.conversation_variables, | conversation_variables=draft_workflow.conversation_variables, | ||||
| marked_name=marked_name, | |||||
| marked_comment=marked_comment, | |||||
| ) | ) | ||||
| # commit db session changes | # commit db session changes | ||||
| db.session.add(workflow) | |||||
| db.session.flush() | |||||
| db.session.commit() | |||||
| app_model.workflow_id = workflow.id | |||||
| db.session.commit() | |||||
| session.add(workflow) | |||||
| # trigger app workflow events | # trigger app workflow events | ||||
| app_published_workflow_was_updated.send(app_model, published_workflow=workflow) | app_published_workflow_was_updated.send(app_model, published_workflow=workflow) | ||||
| ) | ) | ||||
| else: | else: | ||||
| raise ValueError(f"Invalid app mode: {app_model.mode}") | raise ValueError(f"Invalid app mode: {app_model.mode}") | ||||
| def update_workflow( | |||||
| self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict | |||||
| ) -> Optional[Workflow]: | |||||
| """ | |||||
| Update workflow attributes | |||||
| :param session: SQLAlchemy database session | |||||
| :param workflow_id: Workflow ID | |||||
| :param tenant_id: Tenant ID | |||||
| :param account_id: Account ID (for permission check) | |||||
| :param data: Dictionary containing fields to update | |||||
| :return: Updated workflow or None if not found | |||||
| """ | |||||
| stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) | |||||
| workflow = session.scalar(stmt) | |||||
| if not workflow: | |||||
| return None | |||||
| allowed_fields = ["marked_name", "marked_comment"] | |||||
| for field, value in data.items(): | |||||
| if field in allowed_fields: | |||||
| setattr(workflow, field, value) | |||||
| workflow.updated_by = account_id | |||||
| workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| return workflow | |||||
| def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool: | |||||
| """ | |||||
| Delete a workflow | |||||
| :param session: SQLAlchemy database session | |||||
| :param workflow_id: Workflow ID | |||||
| :param tenant_id: Tenant ID | |||||
| :return: True if successful | |||||
| :raises: ValueError if workflow not found | |||||
| :raises: WorkflowInUseError if workflow is in use | |||||
| :raises: DraftWorkflowDeletionError if workflow is a draft version | |||||
| """ | |||||
| stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) | |||||
| workflow = session.scalar(stmt) | |||||
| if not workflow: | |||||
| raise ValueError(f"Workflow with ID {workflow_id} not found") | |||||
| # Check if workflow is a draft version | |||||
| if workflow.version == "draft": | |||||
| raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") | |||||
| # Check if this workflow is currently referenced by an app | |||||
| stmt = select(App).where(App.workflow_id == workflow_id) | |||||
| app = session.scalar(stmt) | |||||
| if app: | |||||
| # Cannot delete a workflow that's currently in use by an app | |||||
| raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") | |||||
| session.delete(workflow) | |||||
| return True |
| from unittest.mock import MagicMock | |||||
| import pytest | |||||
| from models.model import App | |||||
| from models.workflow import Workflow | |||||
| from services.workflow_service import WorkflowService | |||||
| class TestWorkflowService: | |||||
| @pytest.fixture | |||||
| def workflow_service(self): | |||||
| return WorkflowService() | |||||
| @pytest.fixture | |||||
| def mock_app(self): | |||||
| app = MagicMock(spec=App) | |||||
| app.id = "app-id-1" | |||||
| app.workflow_id = "workflow-id-1" | |||||
| app.tenant_id = "tenant-id-1" | |||||
| return app | |||||
| @pytest.fixture | |||||
| def mock_workflows(self): | |||||
| workflows = [] | |||||
| for i in range(5): | |||||
| workflow = MagicMock(spec=Workflow) | |||||
| workflow.id = f"workflow-id-{i}" | |||||
| workflow.app_id = "app-id-1" | |||||
| workflow.created_at = f"2023-01-0{5 - i}" # Descending date order | |||||
| workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2" | |||||
| workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else "" | |||||
| workflows.append(workflow) | |||||
| return workflows | |||||
| def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): | |||||
| mock_app.workflow_id = None | |||||
| mock_session = MagicMock() | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None | |||||
| ) | |||||
| assert workflows == [] | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_not_called() | |||||
| def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| mock_scalar_result.all.return_value = mock_workflows[:3] | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None | |||||
| ) | |||||
| assert workflows == mock_workflows[:3] | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_called_once() | |||||
| def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| # Return 4 items when limit is 3, which should indicate has_more=True | |||||
| mock_scalar_result.all.return_value = mock_workflows[:4] | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None | |||||
| ) | |||||
| # Should return only the first 3 items | |||||
| assert len(workflows) == 3 | |||||
| assert workflows == mock_workflows[:3] | |||||
| assert has_more is True | |||||
| # Test page 2 | |||||
| mock_scalar_result.all.return_value = mock_workflows[3:] | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None | |||||
| ) | |||||
| assert len(workflows) == 2 | |||||
| assert has_more is False | |||||
| def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| # Filter workflows for user-id-1 | |||||
| filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"] | |||||
| mock_scalar_result.all.return_value = filtered_workflows | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1" | |||||
| ) | |||||
| assert workflows == filtered_workflows | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_called_once() | |||||
| # Verify that the select contains a user filter clause | |||||
| args = mock_session.scalars.call_args[0][0] | |||||
| assert "created_by" in str(args) | |||||
| def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| # Filter workflows that have a marked_name | |||||
| named_workflows = [w for w in mock_workflows if w.marked_name] | |||||
| mock_scalar_result.all.return_value = named_workflows | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True | |||||
| ) | |||||
| assert workflows == named_workflows | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_called_once() | |||||
| # Verify that the select contains a named_only filter clause | |||||
| args = mock_session.scalars.call_args[0][0] | |||||
| assert "marked_name !=" in str(args) | |||||
| def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| # Combined filter: user-id-1 and has marked_name | |||||
| filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name] | |||||
| mock_scalar_result.all.return_value = filtered_workflows | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True | |||||
| ) | |||||
| assert workflows == filtered_workflows | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_called_once() | |||||
| # Verify that both filters are applied | |||||
| args = mock_session.scalars.call_args[0][0] | |||||
| assert "created_by" in str(args) | |||||
| assert "marked_name !=" in str(args) | |||||
| def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app): | |||||
| mock_session = MagicMock() | |||||
| mock_scalar_result = MagicMock() | |||||
| mock_scalar_result.all.return_value = [] | |||||
| mock_session.scalars.return_value = mock_scalar_result | |||||
| workflows, has_more = workflow_service.get_all_published_workflow( | |||||
| session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None | |||||
| ) | |||||
| assert workflows == [] | |||||
| assert has_more is False | |||||
| mock_session.scalars.assert_called_once() |