Selaa lähdekoodia

feat(workflow_service): workflow version control api. (#14860)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.0.1
-LAN- 7 kuukautta sitten
vanhempi
commit
3254018ddb
No account linked to committer's email address

+ 207
- 20
api/controllers/console/app/workflow.py Näytä tiedosto

import json import json
import logging import logging
from typing import cast


from flask import abort, request from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound


import services import services
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from factories import variable_factory from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()


parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args()

# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")

workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)

app_model.workflow_id = workflow.id
db.session.commit()

workflow_created_at = TimestampField().format(workflow.created_at)


return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
session.commit()

return {
"result": "success",
"created_at": workflow_created_at,
}




class DefaultBlockConfigsApi(Resource): class DefaultBlockConfigsApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
user_id = args.get("user_id")
named_only = args.get("named_only", False)

if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)

workflow_service = WorkflowService()
with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
page=page,
limit=limit,
user_id=user_id,
named_only=named_only,
)

return {
"items": workflows,
"page": page,
"limit": limit,
"has_more": has_more,
}


class WorkflowByIdApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
page = args.get("page")
limit = args.get("limit")

# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()

# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]

if not update_data:
return {"message": "No valid fields to update"}, 400

workflow_service = WorkflowService() workflow_service = WorkflowService()
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)


return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
# Create a session and manage the transaction
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=app_model.tenant_id,
account_id=current_user.id,
data=update_data,
)

if not workflow:
raise NotFound("Workflow not found")

# Commit the transaction in the controller
session.commit()

return workflow

@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

workflow_service = WorkflowService()

# Create a session and manage the transaction
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
abort(400, description=str(e))
except ValueError as e:
raise NotFound(str(e))

return None, 204




api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource( api.add_resource(
AdvancedChatDraftRunIterationNodeApi, AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
) )
api.add_resource( api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
) )
api.add_resource( api.add_resource(
AdvancedChatDraftRunLoopNodeApi, AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run", "/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
) )
api.add_resource(WorkflowDraftRunLoopNodeApi, "/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource( api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
) )
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

+ 3
- 1
api/fields/workflow_fields.py Näytä tiedosto

"graph": fields.Raw(attribute="graph_dict"), "graph": fields.Raw(attribute="graph_dict"),
"features": fields.Raw(attribute="features_dict"), "features": fields.Raw(attribute="features_dict"),
"hash": fields.String(attribute="unique_hash"), "hash": fields.String(attribute="unique_hash"),
"version": fields.String(attribute="version"),
"version": fields.String,
"marked_name": fields.String,
"marked_comment": fields.String,
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
"created_at": TimestampField, "created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),

+ 29
- 0
api/migrations/versions/2025_03_03_1436-ee79d9b1c156_add_marked_name_and_marked_comment_in_.py Näytä tiedosto

"""add marked_name and marked_comment in workflows

Revision ID: ee79d9b1c156
Revises: 4413929e1ec2
Create Date: 2025-03-03 14:36:05.750346

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'ee79d9b1c156'
down_revision = '5511c782ee4c'
branch_labels = None
depends_on = None


def upgrade():
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default=''))
batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default=''))


def downgrade():
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.drop_column('marked_comment')
batch_op.drop_column('marked_name')

+ 27
- 14
api/models/workflow.py Näytä tiedosto

from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4


if TYPE_CHECKING: if TYPE_CHECKING:
from models.model import AppMode from models.model import AppMode
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False)
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
version: Mapped[str]
marked_name: Mapped[str] = mapped_column(default="", server_default="")
marked_comment: Mapped[str] = mapped_column(default="", server_default="")
graph: Mapped[str] = mapped_column(sa.Text) graph: Mapped[str] = mapped_column(sa.Text)
_features: Mapped[str] = mapped_column("features", sa.TEXT) _features: Mapped[str] = mapped_column("features", sa.TEXT)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
"conversation_variables", db.Text, nullable=False, server_default="{}" "conversation_variables", db.Text, nullable=False, server_default="{}"
) )


def __init__(
self,
@classmethod
def new(
cls,
*, *,
tenant_id: str, tenant_id: str,
app_id: str, app_id: str,
created_by: str, created_by: str,
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
):
self.tenant_id = tenant_id
self.app_id = app_id
self.type = type
self.version = version
self.graph = graph
self.features = features
self.created_by = created_by
self.environment_variables = environment_variables or []
self.conversation_variables = conversation_variables or []
marked_name: str = "",
marked_comment: str = "",
) -> Self:
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
workflow.app_id = app_id
workflow.type = type
workflow.version = version
workflow.graph = graph
workflow.features = features
workflow.created_by = created_by
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow.updated_at = workflow.created_at
return workflow


@property @property
def created_by_account(self): def created_by_account(self):

+ 1195
- 833
api/poetry.lock
File diff suppressed because it is too large
Näytä tiedosto


+ 1
- 0
api/pyproject.toml Näytä tiedosto

pytest-env = "~1.1.3" pytest-env = "~1.1.3"
pytest-mock = "~3.14.0" pytest-mock = "~3.14.0"
types-beautifulsoup4 = "~4.12.0.20241020" types-beautifulsoup4 = "~4.12.0.20241020"
types-deprecated = "~1.2.15.20250304"
types-flask-cors = "~5.0.0.20240902" types-flask-cors = "~5.0.0.20240902"
types-flask-migrate = "~4.1.0.20250112" types-flask-migrate = "~4.1.0.20250112"
types-html5lib = "~1.1.11.20241018" types-html5lib = "~1.1.11.20241018"

+ 10
- 0
api/services/errors/workflow_service.py Näytä tiedosto

class WorkflowInUseError(ValueError):
"""Raised when attempting to delete a workflow that's in use by an app"""

pass


class DraftWorkflowDeletionError(ValueError):
"""Raised when attempting to delete a draft workflow"""

pass

+ 108
- 27
api/services/workflow_service.py Näytä tiedosto

from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4 from uuid import uuid4


from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session


from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter


from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError



class WorkflowService: class WorkflowService:
""" """


return workflow return workflow


def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]:
def get_all_published_workflow(
self,
*,
session: Session,
app_model: App,
page: int,
limit: int,
user_id: str | None,
named_only: bool = False,
) -> tuple[Sequence[Workflow], bool]:
""" """
Get published workflow with pagination Get published workflow with pagination
""" """
if not app_model.workflow_id: if not app_model.workflow_id:
return [], False return [], False


workflows = (
db.session.query(Workflow)
.filter(Workflow.app_id == app_model.id)
.order_by(desc(Workflow.version))
.offset((page - 1) * limit)
stmt = (
select(Workflow)
.where(Workflow.app_id == app_model.id)
.order_by(Workflow.version.desc())
.limit(limit + 1) .limit(limit + 1)
.all()
.offset((page - 1) * limit)
) )


if user_id:
stmt = stmt.where(Workflow.created_by == user_id)

if named_only:
stmt = stmt.where(Workflow.marked_name != "")

workflows = session.scalars(stmt).all()

has_more = len(workflows) > limit has_more = len(workflows) > limit
if has_more: if has_more:
workflows = workflows[:-1] workflows = workflows[:-1]
# return draft workflow # return draft workflow
return workflow return workflow


def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
"""
Publish workflow from draft

:param app_model: App instance
:param account: Account instance
:param draft_workflow: Workflow instance
"""
if not draft_workflow:
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(app_model=app_model)

def publish_workflow(
self,
*,
session: Session,
app_model: App,
account: Account,
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow: if not draft_workflow:
raise ValueError("No valid workflow found.") raise ValueError("No valid workflow found.")


# create new workflow # create new workflow
workflow = Workflow(
workflow = Workflow.new(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
type=draft_workflow.type, type=draft_workflow.type,
created_by=account.id, created_by=account.id,
environment_variables=draft_workflow.environment_variables, environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables, conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
) )


# commit db session changes # commit db session changes
db.session.add(workflow)
db.session.flush()
db.session.commit()

app_model.workflow_id = workflow.id
db.session.commit()
session.add(workflow)


# trigger app workflow events # trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow) app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
) )
else: else:
raise ValueError(f"Invalid app mode: {app_model.mode}") raise ValueError(f"Invalid app mode: {app_model.mode}")

def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
"""
Update workflow attributes

:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:param account_id: Account ID (for permission check)
:param data: Dictionary containing fields to update
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)

if not workflow:
return None

allowed_fields = ["marked_name", "marked_comment"]

for field, value in data.items():
if field in allowed_fields:
setattr(workflow, field, value)

workflow.updated_by = account_id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)

return workflow

def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
"""
Delete a workflow

:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:return: True if successful
:raises: ValueError if workflow not found
:raises: WorkflowInUseError if workflow is in use
:raises: DraftWorkflowDeletionError if workflow is a draft version
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)

if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")

# Check if workflow is a draft version
if workflow.version == "draft":
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

# Check if this workflow is currently referenced by an app
stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")

session.delete(workflow)
return True

+ 162
- 0
api/tests/unit_tests/services/workflow/test_workflow_service.py Näytä tiedosto

from unittest.mock import MagicMock

import pytest

from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService


class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()

@pytest.fixture
def mock_app(self):
app = MagicMock(spec=App)
app.id = "app-id-1"
app.workflow_id = "workflow-id-1"
app.tenant_id = "tenant-id-1"
return app

@pytest.fixture
def mock_workflows(self):
workflows = []
for i in range(5):
workflow = MagicMock(spec=Workflow)
workflow.id = f"workflow-id-{i}"
workflow.app_id = "app-id-1"
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
workflows.append(workflow)
return workflows

def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)

assert workflows == []
assert has_more is False
mock_session.scalars.assert_not_called()

def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = mock_workflows[:3]
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)

assert workflows == mock_workflows[:3]
assert has_more is False
mock_session.scalars.assert_called_once()

def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Return 4 items when limit is 3, which should indicate has_more=True
mock_scalar_result.all.return_value = mock_workflows[:4]
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)

# Should return only the first 3 items
assert len(workflows) == 3
assert workflows == mock_workflows[:3]
assert has_more is True

# Test page 2
mock_scalar_result.all.return_value = mock_workflows[3:]
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
)

assert len(workflows) == 2
assert has_more is False

def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows for user-id-1
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
)

assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()

# Verify that the select contains a user filter clause
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)

def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows that have a marked_name
named_workflows = [w for w in mock_workflows if w.marked_name]
mock_scalar_result.all.return_value = named_workflows
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
)

assert workflows == named_workflows
assert has_more is False
mock_session.scalars.assert_called_once()

# Verify that the select contains a named_only filter clause
args = mock_session.scalars.call_args[0][0]
assert "marked_name !=" in str(args)

def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Combined filter: user-id-1 and has marked_name
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
)

assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()

# Verify that both filters are applied
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
assert "marked_name !=" in str(args)

def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = []
mock_session.scalars.return_value = mock_scalar_result

workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)

assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()

Loading…
Peruuta
Tallenna