
feat: Persist Variables for Enhanced Debugging Workflow (#20699)

This pull request introduces variable persistence to improve the debugging experience during workflow editing. The system now automatically retains the output variables of previously executed nodes, and these persisted variables can be reused when debugging subsequent nodes, eliminating repetitive manual input.

By streamlining this part of the workflow, the feature reduces manual-input errors and debugging effort, offering a smoother and more efficient editing experience.

Key highlights of this change:

- Automatic persistence of output variables for executed nodes.
- Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`).
- Enhanced debugging experience with reduced friction.
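
To make the intended flow concrete, below is a minimal client-side sketch of the new debugging loop against the draft-variable endpoints added in this change. It is illustrative only: the base URL, app ID, node ID, and token are placeholders, not values taken from this diff.

import requests

BASE = "http://localhost:5001/console/api"  # hypothetical console API host
APP_ID = "<app-id>"  # placeholder
NODE_ID = "<node-id>"  # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}  # hypothetical auth

# List the variables persisted by previously executed nodes.
resp = requests.get(f"{BASE}/apps/{APP_ID}/workflows/draft/variables", headers=HEADERS)
for item in resp.json()["items"]:
    print(item["name"], item["value_type"], item["selector"])

# Inspect the last run of a node; its persisted outputs seed the next debug run.
resp = requests.get(f"{BASE}/apps/{APP_ID}/workflows/draft/nodes/{NODE_ID}/last-run", headers=HEADERS)
print(resp.json())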

Closes #19735.
tags/1.5.0
QuantumGhost, 4 months ago
commit 10b738a296
100 changed files with 4786 additions and 759 deletions
1. .github/workflows/api-tests.yml (+6, -0)
2. .github/workflows/vdb-tests.yml (+6, -0)
3. api/.ruff.toml (+51, -52)
4. api/controllers/console/__init__.py (+1, -0)
5. api/controllers/console/app/workflow.py (+66, -4)
6. api/controllers/console/app/workflow_draft_variable.py (+421, -0)
7. api/controllers/console/app/wraps.py (+10, -5)
8. api/controllers/web/error.py (+10, -0)
9. api/core/app/app_config/entities.py (+1, -0)
10. api/core/app/apps/advanced_chat/app_generator.py (+27, -0)
11. api/core/app/apps/advanced_chat/app_runner.py (+6, -2)
12. api/core/app/apps/agent_chat/app_generator.py (+5, -0)
13. api/core/app/apps/chat/app_generator.py (+5, -0)
14. api/core/app/apps/common/workflow_response_converter.py (+11, -5)
15. api/core/app/apps/completion/app_generator.py (+5, -0)
16. api/core/app/apps/workflow/app_generator.py (+27, -1)
17. api/core/app/apps/workflow/app_runner.py (+6, -1)
18. api/core/app/apps/workflow_app_runner.py (+56, -1)
19. api/core/app/entities/app_invoke_entities.py (+15, -0)
20. api/core/file/constants.py (+10, -0)
21. api/core/repositories/sqlalchemy_workflow_execution_repository.py (+6, -1)
22. api/core/repositories/sqlalchemy_workflow_node_execution_repository.py (+13, -3)
23. api/core/variables/segments.py (+14, -0)
24. api/core/variables/types.py (+14, -0)
25. api/core/variables/utils.py (+18, -0)
26. api/core/workflow/conversation_variable_updater.py (+39, -0)
27. api/core/workflow/entities/variable_pool.py (+9, -28)
28. api/core/workflow/graph_engine/entities/event.py (+2, -0)
29. api/core/workflow/graph_engine/graph_engine.py (+9, -0)
30. api/core/workflow/nodes/answer/answer_node.py (+9, -2)
31. api/core/workflow/nodes/answer/answer_stream_processor.py (+2, -0)
32. api/core/workflow/nodes/base/node.py (+46, -5)
33. api/core/workflow/nodes/code/code_node.py (+7, -0)
34. api/core/workflow/nodes/document_extractor/node.py (+6, -2)
35. api/core/workflow/nodes/end/end_node.py (+4, -0)
36. api/core/workflow/nodes/end/end_stream_processor.py (+1, -0)
37. api/core/workflow/nodes/http_request/node.py (+9, -4)
38. api/core/workflow/nodes/if_else/if_else_node.py (+22, -1)
39. api/core/workflow/nodes/iteration/iteration_node.py (+16, -2)
40. api/core/workflow/nodes/iteration/iteration_start_node.py (+4, -0)
41. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+6, -2)
42. api/core/workflow/nodes/list_operator/node.py (+11, -2)
43. api/core/workflow/nodes/llm/file_saver.py (+0, -3)
44. api/core/workflow/nodes/llm/node.py (+5, -1)
45. api/core/workflow/nodes/loop/loop_end_node.py (+4, -0)
46. api/core/workflow/nodes/loop/loop_node.py (+11, -0)
47. api/core/workflow/nodes/loop/loop_start_node.py (+4, -0)
48. api/core/workflow/nodes/node_mapping.py (+5, -0)
49. api/core/workflow/nodes/parameter_extractor/entities.py (+17, -0)
50. api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py (+19, -9)
51. api/core/workflow/nodes/start/start_node.py (+6, -1)
52. api/core/workflow/nodes/template_transform/template_transform_node.py (+4, -0)
53. api/core/workflow/nodes/tool/tool_node.py (+7, -2)
54. api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py (+10, -3)
55. api/core/workflow/nodes/variable_assigner/common/helpers.py (+51, -15)
56. api/core/workflow/nodes/variable_assigner/common/impl.py (+38, -0)
57. api/core/workflow/nodes/variable_assigner/v1/node.py (+71, -3)
58. api/core/workflow/nodes/variable_assigner/v2/entities.py (+6, -0)
59. api/core/workflow/nodes/variable_assigner/v2/exc.py (+5, -0)
60. api/core/workflow/nodes/variable_assigner/v2/node.py (+64, -5)
61. api/core/workflow/variable_loader.py (+79, -0)
62. api/core/workflow/workflow_cycle_manager.py (+7, -7)
63. api/core/workflow/workflow_entry.py (+43, -21)
64. api/core/workflow/workflow_type_encoder.py (+49, -0)
65. api/factories/file_factory.py (+75, -0)
66. api/factories/variable_factory.py (+78, -0)
67. api/libs/datetime_utils.py (+22, -0)
68. api/libs/jsonutil.py (+11, -0)
69. api/models/_workflow_exc.py (+20, -0)
70. api/models/model.py (+15, -0)
71. api/models/workflow.py (+229, -9)
72. api/pyproject.toml (+1, -0)
73. api/services/app_dsl_service.py (+3, -0)
74. api/services/errors/app.py (+4, -0)
75. api/services/workflow_draft_variable_service.py (+721, -0)
76. api/services/workflow_service.py (+225, -17)
77. api/tests/integration_tests/.env.example (+209, -99)
78. api/tests/integration_tests/conftest.py (+80, -8)
79. api/tests/integration_tests/controllers/app_fixture.py (+0, -25)
80. api/tests/integration_tests/controllers/console/__init__.py (+0, -0)
81. api/tests/integration_tests/controllers/console/app/__init__.py (+0, -0)
82. api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py (+47, -0)
83. api/tests/integration_tests/controllers/test_controllers.py (+0, -9)
84. api/tests/integration_tests/factories/__init__.py (+0, -0)
85. api/tests/integration_tests/factories/test_storage_key_loader.py (+371, -0)
86. api/tests/integration_tests/services/__init__.py (+0, -0)
87. api/tests/integration_tests/services/test_workflow_draft_variable_service.py (+501, -0)
88. api/tests/integration_tests/workflow/nodes/test_llm.py (+179, -198)
89. api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py (+302, -0)
90. api/tests/unit_tests/core/app/segments/test_factory.py (+0, -165)
91. api/tests/unit_tests/core/file/test_models.py (+25, -0)
92. api/tests/unit_tests/core/variables/test_segment.py (+0, -0)
93. api/tests/unit_tests/core/variables/test_variables.py (+0, -0)
94. api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py (+7, -6)
95. api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py (+16, -3)
96. api/tests/unit_tests/core/workflow/nodes/test_list_operator.py (+1, -1)
97. api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py (+54, -12)
98. api/tests/unit_tests/core/workflow/test_variable_pool.py (+39, -0)
99. api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py (+34, -14)
100. api/tests/unit_tests/factories/test_variable_factory.py (+0, -0)

.github/workflows/api-tests.yml (+6, -0)

@@ -83,9 +83,15 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
redis
sandbox
ssrf_proxy

- name: setup test config
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh


.github/workflows/vdb-tests.yml (+6, -0)

@@ -84,6 +84,12 @@ jobs:
elasticsearch
oceanbase

- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

- name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py


api/.ruff.toml (+51, -52)

@@ -1,6 +1,4 @@
- exclude = [
-     "migrations/*",
- ]
+ exclude = ["migrations/*"]
line-length = 120

[format]
@@ -9,14 +7,14 @@ quote-style = "double"
[lint]
preview = false
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
@@ -24,19 +22,19 @@ select = [
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
@@ -47,36 +45,37 @@ select = [
]

ignore = [
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]

[lint.per-file-ignores]

api/controllers/console/__init__.py (+1, -0)

@@ -63,6 +63,7 @@ from .app import (
statistic,
workflow,
workflow_app_log,
workflow_draft_variable,
workflow_run,
workflow_statistic,
)

api/controllers/console/app/workflow.py (+66, -4)

@@ -1,5 +1,6 @@
import json
import logging
from collections.abc import Sequence
from typing import cast

from flask import abort, request
@@ -18,10 +19,12 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from extensions.ext_database import db
- from factories import variable_factory
+ from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -30,6 +33,7 @@ from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
@@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)


# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
files = files or []

file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
file_objs: Sequence[File] = []
if file_extra_config is None:
return file_objs
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)
return file_objs


class DraftWorkflowApi(Resource):
@setup_required
@login_required
@@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource):

parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args()

- inputs = args.get("inputs")
- if inputs == None:
+ user_inputs = args.get("inputs")
+ if user_inputs is None:
raise ValueError("missing inputs")

workflow_srv = WorkflowService()
# fetch draft workflow by app_model
draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService()

workflow_node_execution = workflow_service.run_draft_workflow_node(
- app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
+ app_model=app_model,
+ draft_workflow=draft_workflow,
+ node_id=node_id,
+ user_inputs=user_inputs,
+ account=current_user,
+ query=args.get("query", ""),
+ files=files,
)

return workflow_node_execution
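
For reference, a request body accepted by the parser above might look like the sketch below. The values are illustrative; the file mapping mirrors the "Local File" payload documented in VariableApi.patch further down.

payload = {
    "inputs": {"arg1": "hello"},  # required dict of user inputs for the node
    "query": "",  # optional, defaults to ""
    "files": [  # optional; parsed by _parse_file via file_factory
        {
            "type": "image",
            "transfer_method": "local_file",
            "url": "",
            "upload_file_id": "<upload-file-id>",  # placeholder
        }
    ],
}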
@@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource):
return None, 204


class DraftWorkflowNodeLastRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model)
if not workflow:
raise NotFound("Workflow not found")
node_exec = srv.get_node_last_run(
app_model=app_model,
workflow=workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("last run not found")
return node_exec


api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
@@ -795,3 +853,7 @@ api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)

api/controllers/console/app/workflow_draft_variable.py (+421, -0)

@@ -0,0 +1,421 @@
import logging
from typing import Any, NoReturn

from flask import Response
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode, db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)


def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value


def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)


def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser


_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}

_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
)

_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda _: "env"),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}

_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
}


def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables


_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}

_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}


def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.

It ensures the following conditions are satisfied:

- Dify has been properly set up.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""

@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)

return wrapper


class WorkflowVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()

# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
if not workflow_exist:
raise DraftWorkflowNotExist()

# list draft variables for this app
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=app_model.id,
page=args.page,
limit=args.limit,
)

return workflow_vars

@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(app_model.id)
db.session.commit()
return Response("", 204)


def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While system and conversation variables are stored as node variables
# with specific `node_id` values in the database, we still want to keep the APIs separate.
# By disallowing access to system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk of API users depending on the implementation details of this API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)

raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None


class NodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)

return node_vars

@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(app_model.id, node_id)
db.session.commit()
return Response("", 204)


class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"

@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable

@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
# Local File:
#
# {
# "type": "image",
# "transfer_method": "local_file",
# "url": "",
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
# }
#
# Remote File:
#
#
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }

parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")

draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)

variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")

new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable

new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable

@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)


class VariableResetApi(Resource):
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)

workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")

resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)


def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
return draft_vars


class ConversationVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)


class SystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)


class EnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, app_model: App):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if workflow is None:
raise DraftWorkflowNotExist()

env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)

return {"items": env_vars_list}


api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")

api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
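
As a rough usage sketch of the endpoints registered above (host, IDs, and auth header are placeholders; the payload shape follows the comments in VariableApi.patch):

import requests

BASE = "http://localhost:5001/console/api"  # hypothetical console API host
HEADERS = {"Authorization": "Bearer <console-token>"}  # hypothetical auth
url = f"{BASE}/apps/<app-id>/workflows/draft/variables/<variable-id>"  # placeholders

# Rename a draft variable and replace its value with a local file.
resp = requests.patch(
    url,
    headers=HEADERS,
    json={
        "name": "uploaded_image",
        "value": {
            "type": "image",
            "transfer_method": "local_file",
            "url": "",
            "upload_file_id": "<upload-file-id>",  # placeholder
        },
    },
)
print(resp.json())

# Reset the variable to its default, or delete it outright.
requests.put(f"{url}/reset", headers=HEADERS)
requests.delete(url, headers=HEADERS)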

api/controllers/console/app/wraps.py (+10, -5)

@@ -8,6 +8,15 @@ from libs.login import current_user
from models import App, AppMode


def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model


def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
@wraps(view_func)
@@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[

del kwargs["app_id"]

- app_model = (
-     db.session.query(App)
-     .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
-     .first()
- )
+ app_model = _load_app_model(app_id)

if not app_model:
raise AppNotFoundError()

api/controllers/web/error.py (+10, -0)

@@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429


class NotFoundError(BaseHTTPException):
error_code = "not_found"
code = 404


class InvalidArgumentError(BaseHTTPException):
error_code = "invalid_param"
code = 400

api/core/app/app_config/entities.py (+1, -0)

@@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
Variable Entity.
"""

# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""

api/core/app/apps/advanced_chat/app_generator.py (+27, -0)

@@ -29,6 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
@@ -36,6 +37,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService

logger = logging.getLogger(__name__)

@@ -116,6 +118,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
@@ -261,6 +268,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)

return self._generate(
workflow=workflow,
@@ -271,6 +285,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)

def single_loop_generate(
@@ -336,6 +351,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)

return self._generate(
workflow=workflow,
@@ -346,6 +368,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)

def _generate(
@@ -359,6 +382,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@@ -410,6 +434,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
},
)

@@ -438,6 +463,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id: str,
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
"""
Generate worker in a new thread.
@@ -464,6 +490,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
)

runner.run()

api/core/app/apps/advanced_chat/app_runner.py (+6, -2)

@@ -19,6 +19,7 @@ from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
) -> None:
- super().__init__(queue_manager)
+ super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count

def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id

def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)

api/core/app/apps/agent_chat/app_generator.py (+5, -0)

@@ -124,6 +124,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:

api/core/app/apps/chat/app_generator.py (+5, -0)

@@ -115,6 +115,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:

api/core/app/apps/common/workflow_response_converter.py (+11, -5)

@@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@@ -125,7 +126,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
- outputs=workflow_execution.outputs,
+ outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
@@ -202,6 +203,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None

json_converter = WorkflowRuntimeTypeConverter()

return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@@ -214,7 +217,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
- outputs=workflow_node_execution.outputs,
+ outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@@ -245,6 +248,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None

json_converter = WorkflowRuntimeTypeConverter()

return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@@ -257,7 +262,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
- outputs=workflow_node_execution.outputs,
+ outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@@ -376,6 +381,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@@ -384,7 +390,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
- outputs=event.outputs,
+ outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@@ -463,7 +469,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
- outputs=event.outputs,
+ outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},

api/core/app/apps/completion/app_generator.py (+5, -0)

@@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:

api/core/app/apps/workflow/app_generator.py (+27, -1)

@@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService

logger = logging.getLogger(__name__)

@@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
files: Sequence[Mapping[str, Any]] = args.get("files") or []

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
@@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)

@@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)

return self._generate(
app_model=app_model,
@@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)

def single_loop_generate(
@@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)

draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,
workflow=workflow,
@@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)

def _generate_worker(
@@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
)

runner.run()

api/core/app/apps/workflow/app_runner.py (+6, -1)

@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id

def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id

def run(self) -> None:
"""
Run application

api/core/app/apps/workflow_app_runner.py (+56, -1)

@@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import Any, Optional, cast

from sqlalchemy.orm import Session

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
@@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
BaseNodeEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
@@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
from services.workflow_draft_variable_service import (
DraftVariableSaver,
)


class WorkflowBasedAppRunner(AppRunner):
- def __init__(self, queue_manager: AppQueueManager):
+ def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager
self._variable_loader = variable_loader

def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")

def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError:
variable_mapping = {}

load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
@@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)

elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
@@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)

elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
@@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner):

def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

def _save_draft_var_for_event(self, event: BaseNodeEvent):
run_result = event.route_node_state.node_run_result
if run_result is None:
return
process_data = run_result.process_data
outputs = run_result.outputs
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=self._get_app_id(),
node_id=event.node_id,
node_type=event.node_type,
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
invoke_from=self.queue_manager._invoke_from,
node_execution_id=event.id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
)
draft_var_saver.save(process_data=process_data, outputs=outputs)


def _remove_first_element_from_variable_string(key: str) -> str:
"""
Remove the first element from the prefix.
"""
prefix, remaining = key.split(".", maxsplit=1)
return remaining

api/core/app/entities/app_invoke_entities.py (+15, -0)

@@ -17,9 +17,24 @@ class InvokeFrom(Enum):
Invoke From.
"""

# SERVICE_API indicates that this invocation is from an API call to Dify app.
#
# Description of service api in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api"

# WEB_APP indicates that this invocation is from
# the web app of the workflow (or chatflow).
#
# Description of web app in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"

# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"

@classmethod

api/core/file/constants.py (+10, -0)

@@ -1 +1,11 @@
from typing import Any

# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"


def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
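
For illustration, the helper matches only dicts carrying the marker identity:

assert maybe_file_object({"dify_model_identity": "__dify__file__", "type": "image"})
assert not maybe_file_object({"type": "image"})
assert not maybe_file_object("__dify__file__")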

api/core/repositories/sqlalchemy_workflow_execution_repository.py (+6, -1)

@@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import (
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@@ -152,7 +153,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
- db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
+ db_model.outputs = (
+     json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
+     if domain_model.outputs
+     else None
+ )
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens

api/core/repositories/sqlalchemy_workflow_node_execution_repository.py (+13, -3)

@@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import (
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")

json_converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
@@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
- db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
- db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
- db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
+ db_model.inputs = (
+     json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
+ )
+ db_model.process_data = (
+     json.dumps(json_converter.to_json_encodable(domain_model.process_data))
+     if domain_model.process_data
+     else None
+ )
+ db_model.outputs = (
+     json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
+ )
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time

+ 14
- 0
api/core/variables/segments.py Show file

@@ -75,6 +75,20 @@ class StringSegment(Segment):
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: float
# NOTE(QuantumGhost): equality comparison for `FloatSegment` with `NaN` values appears to be broken.
# The following tests cannot pass.
#
# def test_float_segment_and_nan():
# nan = float("nan")
# assert nan != nan
#
# f1 = FloatSegment(value=float("nan"))
# f2 = FloatSegment(value=float("nan"))
# assert f1 != f2
#
# f3 = FloatSegment(value=nan)
# f4 = FloatSegment(value=nan)
# assert f3 != f4


class IntegerSegment(Segment):

+ 14
- 0
api/core/variables/types.py Show file

@@ -18,3 +18,17 @@ class SegmentType(StrEnum):
NONE = "none"

GROUP = "group"

def is_array_type(self):
return self in _ARRAY_TYPES


_ARRAY_TYPES = frozenset(
[
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)
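
A quick behavioral sketch, assuming the enum members shown in this hunk:

from core.variables.types import SegmentType

assert SegmentType.ARRAY_STRING.is_array_type()
assert SegmentType.ARRAY_FILE.is_array_type()
assert not SegmentType.NONE.is_array_type()  # scalar types are not array types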

+ 18
- 0
api/core/variables/utils.py Show file

@@ -1,8 +1,26 @@
import json
from collections.abc import Iterable, Sequence

from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment


def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
selectors = [node_id, name]
if paths:
selectors.extend(paths)
return selectors


class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
return super().default(o)
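
A usage sketch, assuming `StringSegment` from this package; `default` is invoked by `json.dumps` for any value it cannot serialize natively:

import json

from core.variables.segments import StringSegment
from core.variables.utils import SegmentJSONEncoder

payload = {"greeting": StringSegment(value="hello")}
print(json.dumps(payload, cls=SegmentJSONEncoder))  # {"greeting": "hello"}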

+ 39
- 0
api/core/workflow/conversation_variable_updater.py Show file

@@ -0,0 +1,39 @@
import abc
from typing import Protocol

from core.variables import Variable


class ConversationVariableUpdater(Protocol):
"""
ConversationVariableUpdater defines an abstraction for updating conversation variable values.

It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
conversation variables.

Implementations may choose to batch updates. If batching is used, the `flush` method
should be implemented to persist buffered changes, and `update`
should handle buffering accordingly.

Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
"""

@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
"""
Updates the value of the specified conversation variable in the underlying storage.

:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
"""
pass

@abc.abstractmethod
def flush(self):
"""
Flushes all pending updates to the underlying storage system.

If the implementation does not buffer updates, this method can be a no-op.
"""
pass
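
A hypothetical buffering implementation, only to illustrate the `update`/`flush` contract described above; `_persist` stands in for a real storage call and does not exist in the codebase:

from core.variables import Variable


class BufferedConversationVariableUpdater:
    def __init__(self) -> None:
        # Pending (conversation_id, variable) pairs, written only on flush().
        self._pending: list[tuple[str, Variable]] = []

    def update(self, conversation_id: str, variable: Variable) -> None:
        # Buffer the change instead of hitting storage immediately.
        self._pending.append((conversation_id, variable))

    def flush(self) -> None:
        # Persist all buffered changes in one batch, then clear the buffer.
        for conversation_id, variable in self._pending:
            _persist(conversation_id, variable)  # hypothetical storage call
        self._pending.clear()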

+ 9
- 28
api/core/workflow/entities/variable_pool.py Show file

@@ -7,12 +7,12 @@ from pydantic import BaseModel, Field

from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from factories import variable_factory

from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, File]

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -30,9 +30,11 @@ class VariablePool(BaseModel):
# TODO: These user inputs are not used by the pool.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables",
default_factory=dict,
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
@@ -43,28 +45,7 @@ class VariablePool(BaseModel):
default_factory=list,
)

def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}

super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)

def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
@@ -91,12 +72,12 @@ class VariablePool(BaseModel):
Returns:
None
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")

if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)
@@ -118,7 +99,7 @@ class VariablePool(BaseModel):
Raises:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
return None

hash_key = hash(tuple(selector[1:]))
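
A usage sketch for the selector-based API; selectors need at least `MIN_SELECTORS_LENGTH` (two) parts, and constructing an empty pool relies on the field defaults shown above:

from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool()
pool.add(("node_1", "result"), "hello")  # plain values are wrapped into segments
segment = pool.get(("node_1", "result"))
assert segment is not None and segment.to_object() == "hello"
assert pool.get(("node_1",)) is None     # selector too short, returns None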

+ 2
- 0
api/core/workflow/graph_engine/entities/event.py Show file

@@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"


class NodeRunStartedEvent(BaseNodeEvent):

+ 9
- 0
api/core/workflow/graph_engine/graph_engine.py Show file

@@ -314,6 +314,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
raise e

@@ -627,6 +628,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
)

max_retries = node_instance.node_data.retry_config.max_retries
@@ -677,6 +679,7 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
)
time.sleep(retry_interval)
break
@@ -712,6 +715,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
else:
@@ -726,6 +730,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@@ -786,6 +791,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False

@@ -803,6 +809,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
@@ -817,6 +824,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@@ -833,6 +841,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
return
except Exception as e:

+ 9
- 2
api/core/workflow/nodes/answer/answer_node.py Show file

@@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser

class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
_node_type = NodeType.ANSWER

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
@@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part)
answer += part.text

return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
)

@classmethod
def _extract_variable_selector_to_variable_mapping(

+ 2
- 0
api/core/workflow/nodes/answer/answer_stream_processor.py Show file

@@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)

self.route_position[answer_node_id] += 1

+ 46
- 5
api/core/workflow/nodes/base/node.py Show file

@@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)

class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
_node_type: NodeType
_node_type: ClassVar[NodeType]

def __init__(
self,
@@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]):
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
"""Extracts referenced variable selectors from the node configuration.

The `config` parameter represents the configuration for a specific node type and corresponds
to the `data` field in the node definition object.

The returned mapping has the following structure:

{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}

For loop and iteration nodes, the mapping may look like this:

{
"1748332301644.input_selector": ["1748332363630", "result"],
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
}

where `1748332301644` is the ID of the loop / iteration node,
and `1748332325079` is the ID of the node inside the loop or iteration node.

Here, the key consists of two parts: the current node ID (provided as the `node_id`
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
enclosed in `#` symbols. These two parts are separated by a dot (`.`).

The value is a list of strings representing the variable selector, where the first element is the node ID
of the referenced variable, and the second element is the variable name within that node.

The meaning of the above response is:

The node with ID `1747829548239` references the variable `result` from the node with
ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a
reference to the `result` output variable of node `1747829667553`.

:param graph_config: graph config
:param config: node config
:return:
@@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")

node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
return data

@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]):
"""
return self._node_type

@classmethod
@abstractmethod
def version(cls) -> str:
"""Returns the version of the current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error

+ 7
- 0
api/core/workflow/nodes/code/code_node.py Show file

@@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):

return code_provider.get_default_config()

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
@@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "",
depth: int = 1,
):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")


+ 6
- 2
api/core/workflow/nodes/document_extractor/node.py Show file

@@ -24,7 +24,7 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR

@classmethod
def version(cls) -> str:
return "1"

def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": extracted_text_list},
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)

+ 4
- 0
api/core/workflow/nodes/end/end_node.py Show file

@@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
Run node

+ 1
- 0
api/core/workflow/nodes/end/end_stream_processor.py Show file

@@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)

self.route_position[end_node_id] += 1

+ 9
- 4
api/core/workflow/nodes/http_request/node.py Show file

@@ -6,6 +6,7 @@ from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
},
}

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
process_data = {}
try:
@@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"body": response.text if not files.value else "",
"headers": response.headers,
"files": files,
},
@@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

return mapping

def extract_files(self, url: str, response: Response) -> list[File]:
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
"""
Extract files from response by checking both Content-Type header and URL
"""
@@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None

if not is_file:
return files
return ArrayFileSegment(value=[])

if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename()
@@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
files.append(file)

return files
return ArrayFileSegment(value=files)

+ 22
- 1
api/core/workflow/nodes/if_else/if_else_node.py Show file

@@ -1,4 +1,5 @@
from typing import Literal
from collections.abc import Mapping, Sequence
from typing import Any, Literal

from typing_extensions import deprecated

@@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
Run node
@@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):

return data

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector

return var_mapping
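
For illustration, with hypothetical IDs: a single case whose condition reads selector `['1747829667553', 'result']` inside node `1747829548239` produces:

{
    "1747829548239.#1747829667553.result#": ["1747829667553", "result"],
}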


@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(

+ 16
- 2
api/core/workflow/nodes/iteration/iteration_node.py Show file

@@ -11,6 +11,7 @@ from flask import Flask, current_app

from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
)
@@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts

from .exc import (
@@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]):
},
}

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
@@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")

if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": []},
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
return
@@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)

yield IterationRunSucceededEvent(
iteration_id=self.id,
@@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,

+ 4
- 0
api/core/workflow/nodes/iteration/iteration_start_node.py Show file

@@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
Run the node.

+ 6
- 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py Show file

@@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
@@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs, # type: ignore
)

except KnowledgeRetrievalNodeError as e:

+ 11
- 2
api/core/workflow/nodes/list_operator/node.py Show file

@@ -3,6 +3,7 @@ from typing import Any, Literal, Union

from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR

@classmethod
def version(cls) -> str:
return "1"

def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
@@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
if isinstance(variable, ArraySegment):
result = variable.model_copy(update={"value": []})
else:
result = ArrayAnySegment(value=[])
outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
variable = self._apply_slice(variable)

outputs = {
"result": variable.value,
"result": variable,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}

+ 0
- 3
api/core/workflow/nodes/llm/file_saver.py Show file

@@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)


+ 5
- 1
api/core/workflow/nodes/llm/node.py Show file

@@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""
@@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
outputs["files"] = ArrayFileSegment(value=self._file_outputs)

yield RunCompletedEvent(
run_result=NodeRunResult(

+ 4
- 0
api/core/workflow/nodes/loop/loop_end_node.py Show file

@@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
Run the node.

+ 11
- 0
api/core/workflow/nodes/loop/loop_node.py Show file

@@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
@@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]):

variable_mapping.update(sub_node_variable_mapping)

for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector

# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids

+ 4
- 0
api/core/workflow/nodes/loop/loop_start_node.py Show file

@@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
"""
Run the node.

+ 5
- 0
api/core/workflow/nodes/node_mapping.py Show file

@@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var

LATEST_VERSION = "latest"

# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
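
A hypothetical sketch of registering a second version for a node type, keeping `version()` and this mapping in sync; `CodeNodeV2` is invented for illustration:

{
    NodeType.CODE: {
        LATEST_VERSION: CodeNodeV2,  # hypothetical subclass whose version() returns "2"
        "1": CodeNode,
        "2": CodeNodeV2,
    },
}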

+ 17
- 0
api/core/workflow/nodes/parameter_extractor/entities.py Show file

@@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig


class _ParameterConfigError(Exception):
pass


class ParameterConfig(BaseModel):
"""
Parameter Config.
@@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value)

def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")

def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
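
A small sketch of the two helpers above; the constructor fields shown (`name`, `type`, `description`, `required`) are assumed from context:

param = ParameterConfig(name="tags", type="array[string]", description="demo", required=True)
assert param.is_array_type()
assert param.element_type() == "string"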


class ParameterExtractorNodeData(BaseNodeData):
"""

+ 19
- 9
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Show file

@@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type

from .entities import ParameterExtractorNodeData
from .exc import (
@@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode):
}
}

@classmethod
def version(cls) -> str:
return "1"

def _run(self):
"""
Run the node.
@@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith("array"):
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1]
transformed_result[parameter.name] = []
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
transformed_result[parameter.name].append(float(item))
segment_value.value.append(float(item))
else:
transformed_result[parameter.name].append(int(item))
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
if isinstance(item, str):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif nested_type == "object":
if isinstance(item, dict):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)

if parameter.name not in transformed_result:
if parameter.type == "number":
@@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
transformed_result[parameter.name] = []
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)

return transformed_result


+ 6
- 1
api/core/workflow/nodes/start/start_node.py Show file

@@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
@@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)

return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

+ 4
- 0
api/core/workflow/nodes/template_transform/template_transform_node.py Show file

@@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
# Get variables
variables = {}

+ 7
- 2
api/core/workflow/nodes/tool/tool_node.py Show file

@@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> Generator:
"""
Run the tool node
@@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta["file"], File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

+ 10
- 3
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py Show file

@@ -1,3 +1,6 @@
from collections.abc import Mapping

from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR

@classmethod
def version(cls) -> str:
return "1"

def _run(self) -> NodeRunResult:
# Get variables
outputs = {}
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}

if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable.to_object()}
outputs = {"output": variable}

inputs = {".".join(selector[1:]): variable.to_object()}
break
@@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
variable = self.graph_runtime_state.variable_pool.get(selector)

if variable is not None:
outputs[group.group_name] = {"output": variable.to_object()}
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break


+ 51
- 15
api/core/workflow/nodes/variable_assigner/common/helpers.py Show file

@@ -1,19 +1,55 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, TypeVar

from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
from pydantic import BaseModel

from core.variables import Segment
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.types import SegmentType

def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.
_UPDATED_VARIABLES_KEY = "__updated_variables"


class UpdatedVariable(BaseModel):
name: str
selector: Sequence[str]
value_type: SegmentType
new_value: Any


_T = TypeVar("_T", bound=MutableMapping[str, Any])


def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return UpdatedVariable(
name=var_name,
selector=list(selector[:2]),
value_type=seg.value_type,
new_value=seg.value,
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()


def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
m[_UPDATED_VARIABLES_KEY] = updates
return m


def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
if updated_values is None:
return None
result = []
for item in updated_values:
    if isinstance(item, UpdatedVariable):
        result.append(item)
    elif isinstance(item, dict):
        item = UpdatedVariable.model_validate(item)
        result.append(item)
    else:
        raise TypeError(f"Invalid updated variable: {item}, type={type(item)}")
return result
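
A sketch of the intended round trip through a node's `process_data` mapping, using the helpers above (field values are illustrative):

from core.variables.types import SegmentType

updated = UpdatedVariable(
    name="memo",
    selector=["conversation", "memo"],
    value_type=SegmentType.STRING,
    new_value="hello",
)
process_data = set_updated_variables({}, [updated])
assert get_updated_variables(process_data) == [updated]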

+ 38
- 0
api/core/workflow/nodes/variable_assigner/common/impl.py Show file

@@ -0,0 +1,38 @@
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session

from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable

from .exc import VariableOperatorNodeError


class ConversationVariableUpdaterImpl:
_engine: Engine | None

def __init__(self, engine: Engine | None = None) -> None:
self._engine = engine

def _get_engine(self) -> Engine:
if self._engine:
return self._engine
return db.engine

def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(self._get_engine()) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

def flush(self):
pass


def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

+ 71
- 3
api/core/workflow/nodes/variable_assigner/v1/node.py Show file

@@ -1,4 +1,9 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias

from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory

from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode

if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState


_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
self._conv_var_updater_factory = conv_var_updater_factory

@classmethod
def version(cls) -> str:
return "1"

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector

selector_key = ".".join(node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector
return mapping

def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")

@@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")

# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
"value": income_value.to_object(),
},
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still record the updated variables as a list to keep the schema of
# `process_data` compatible with `v2.VariableAssignerNode`.
process_data=common_helpers.set_updated_variables({}, updated_variables),
outputs={},
)



+ 6
- 0
api/core/workflow/nodes/variable_assigner/v2/entities.py Show file

@@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
# NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context:
#
# 1. For CONSTANT input_type: Contains the literal value to be used in the operation.
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
# 3. During the variable updating procedure: The `value` field is reassigned to hold
# the resolved actual value that will be applied to the target variable.
value: Any | None = None



+ 5
- 0
api/core/workflow/nodes/variable_assigner/v2/exc.py Show file

@@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError):
class ConversationIDNotFoundError(VariableOperatorNodeError):
def __init__(self):
super().__init__("conversation_id not found")


class InvalidDataError(VariableOperatorNodeError):
def __init__(self, message: str) -> None:
super().__init__(message)

+ 64
- 5
api/core/workflow/nodes/variable_assigner/v2/node.py Show file

@@ -1,34 +1,84 @@
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory

from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)

_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
if selector_node_id != CONVERSATION_VARIABLE_NODE_ID:
return
selector_str = ".".join(item.variable_selector)
key = f"{node_id}.#{selector_str}#"
mapping[key] = item.variable_selector


def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
# Keep this in sync with the logic in the `_run` method.
if item.input_type != InputType.VARIABLE:
return
selector = item.value
if not isinstance(selector, list):
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
if len(selector) < MIN_SELECTORS_LENGTH:
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
selector_str = ".".join(selector)
key = f"{node_id}.#{selector_str}#"
mapping[key] = selector
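
For illustration, with hypothetical IDs and assuming `CONVERSATION_VARIABLE_NODE_ID` is `"conversation"`: an item in node `1748332301644` that targets conversation variable `['conversation', 'memo']` and reads its value from `['1747829667553', 'result']` contributes both a target and a source entry:

{
    "1748332301644.#conversation.memo#": ["conversation", "memo"],
    "1748332301644.#1747829667553.result#": ["1747829667553", "result"],
}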


class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER

def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()

@classmethod
def version(cls) -> str:
return "2"

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData,
) -> Mapping[str, Sequence[str]]:
var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
@@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))

conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
@@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)

conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
]

process_data = common_helpers.set_updated_variables(process_data, updated_variables)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={},
)

def _handle_item(

+ 79
- 0
api/core/workflow/variable_loader.py Show file

@@ -0,0 +1,79 @@
import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol

from core.variables import Variable
from core.workflow.entities.variable_pool import VariablePool


class VariableLoader(Protocol):
"""Interface for loading variables based on selectors.

A `VariableLoader` is responsible for retrieving additional variables required during the execution
of a single node, which are not provided as user inputs.

NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
application and share the same `app_id`. However, this interface does not enforce that constraint,
and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
concerns and allow for flexible implementations.

Implementations of `VariableLoader` should almost always have an `app_id` parameter in
their constructor.

TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into
`WorkflowService.single_step_run`, we may get rid of this interface.
"""

@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.

The order of the returned variables is not guaranteed. If the caller wants to ensure
a specific order, they should sort the returned list themselves.

:param selectors: a list of string lists; each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
"""
pass


class _DummyVariableLoader(VariableLoader):
"""A dummy implementation of VariableLoader that does not load any variables.
Serves as a placeholder when no variable loading is needed.
"""

def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
return []


DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
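
A hypothetical in-memory implementation of the protocol, to make the contract concrete; a real implementation would load persisted draft variables scoped to one `app_id`:

from core.variables import Variable


class InMemoryVariableLoader(VariableLoader):
    def __init__(self, variables: list[Variable]) -> None:
        # Index by (node_id, variable_name), the first two selector parts.
        self._by_selector = {tuple(v.selector[:2]): v for v in variables}

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        return [
            self._by_selector[key]
            for selector in selectors
            if (key := tuple(selector[:2])) in self._by_selector
        ]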


def load_into_variable_pool(
variable_loader: VariableLoader,
variable_pool: VariablePool,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: Mapping[str, Any],
):
# Loading missing variable from draft var here, and set it into
# variable_pool.
variables_to_load: list[list[str]] = []
for key, selector in variable_mapping.items():
# NOTE(QuantumGhost): this logic needs to be in sync with
# `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
node_variable_list = key.split(".")
if len(node_variable_list) < 1:
raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
if key in user_inputs:
continue
node_variable_key = ".".join(node_variable_list[1:])
if node_variable_key in user_inputs:
continue
if variable_pool.get(selector) is None:
variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load)
for var in loaded:
variable_pool.add(var.selector, var)

+ 7
- 7
api/core/workflow/workflow_cycle_manager.py Show file

@@ -92,7 +92,7 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

outputs = WorkflowEntry.handle_special_values(outputs)
# outputs = WorkflowEntry.handle_special_values(outputs)

workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
@@ -125,7 +125,7 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
@@ -242,9 +242,9 @@ class WorkflowCycleManager:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")

# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs

# Convert metadata keys to strings
execution_metadata_dict = {}
@@ -289,7 +289,7 @@ class WorkflowCycleManager:
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs

# Convert metadata keys to strings
execution_metadata_dict = {}
@@ -326,7 +326,7 @@ class WorkflowCycleManager:
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs

# Convert metadata keys to strings
origin_metadata = {

+ 43
- 21
api/core/workflow/workflow_entry.py Show file

@@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@@ -119,7 +120,9 @@ class WorkflowEntry:
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
@@ -129,29 +132,14 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
# fetch node info from workflow graph
workflow_graph = workflow.graph_dict
if not workflow_graph:
raise ValueError("workflow graph not found")

nodes = workflow_graph.get("nodes")
if not nodes:
raise ValueError("nodes not found in workflow graph")

# fetch node config from node id
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise ValueError("node id not found in workflow graph")
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config.get("data", {})

# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)

# init graph
graph = Graph.init(graph_config=workflow.graph_dict)

@@ -182,16 +170,33 @@ class WorkflowEntry:
except NotImplementedError:
variable_mapping = {}

# Loading missing variable from draft var here, and set it into
# variable_pool.
load_into_variable_pool(
variable_loader=variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)

try:
# run node
generator = node_instance.run()
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator

@@ -294,10 +299,20 @@ class WorkflowEntry:

return node_instance, generator
except Exception as e:
logger.exception(
    "error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))

@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)

@@ -324,10 +339,17 @@ class WorkflowEntry:
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
# NOTE(QuantumGhost): This logic should remain synchronized with
# the implementation of `load_into_variable_pool`, specifically the logic about
# variable existence checking.

# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
# and multiple parts of the codebase depend on its current behavior.
# Modify with caution.
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
node_variable_list = node_variable.split(".")

+ 49
- 0
api/core/workflow/workflow_type_encoder.py Show file

@@ -0,0 +1,49 @@
import json
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel

from core.file.models import File
from core.variables import Segment


class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
def default(self, o: Any):
if isinstance(o, Segment):
return o.value
elif isinstance(o, File):
return o.to_dict()
elif isinstance(o, BaseModel):
return o.model_dump(mode="json")
else:
return super().default(o)


class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)

def _to_json_encodable_recursive(self, value: Any) -> Any:
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
return res_list
return value
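
A minimal usage sketch of the converter above; `Meta` is a throwaway pydantic model invented for illustration:

class Meta(BaseModel):
    author: str

converter = WorkflowRuntimeTypeConverter()
payload = {"text": "hi", "meta": Meta(author="alice"), "nums": [1, 2.5]}
encodable = converter.to_json_encodable(payload)
# encodable == {"text": "hi", "meta": {"author": "alice"}, "nums": [1, 2.5]}
json.dumps(encodable)  # or equivalently: json.dumps(payload, cls=WorkflowRuntimeTypeEncoder)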

+ 75
- 0
api/factories/file_factory.py Show file

@@ -5,6 +5,7 @@ from typing import Any, cast

import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session

from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
@@ -91,6 +92,8 @@ def build_from_mappings(
tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]:
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
files = [
build_from_mapping(
mapping=mapping,
@@ -377,3 +380,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:

def get_file_type_by_mime_type(mime_type: str) -> FileType:
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM


class StorageKeyLoader:
    """StorageKeyLoader loads storage keys from the database for a list of files.

    This loader is batched: it issues a fixed number of queries regardless of how
    many files are passed in.
"""

def __init__(self, session: Session, tenant_id: str) -> None:
self._session = session
self._tenant_id = tenant_id

def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
stmt = select(UploadFile).where(
UploadFile.id.in_(upload_file_ids),
UploadFile.tenant_id == self._tenant_id,
)

return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
stmt = select(ToolFile).where(
ToolFile.id.in_(tool_file_ids),
ToolFile.tenant_id == self._tenant_id,
)
return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}

def load_storage_keys(self, files: Sequence[File]):
"""Loads storage keys for a sequence of files by retrieving the corresponding
`UploadFile` or `ToolFile` records from the database based on their transfer method.

This method doesn't modify the input sequence structure but updates the `_storage_key`
property of each file object by extracting the relevant key from its database record.

Performance note: This is a batched operation where database query count remains constant
regardless of input size. However, for optimal performance, input sequences should contain
fewer than 1000 files. For larger collections, split into smaller batches and process each
batch separately.
"""

upload_file_ids: list[uuid.UUID] = []
tool_file_ids: list[uuid.UUID] = []
for file in files:
related_model_id = file.related_id
if file.related_id is None:
raise ValueError("file id should not be None.")
if file.tenant_id != self._tenant_id:
err_msg = (
f"invalid file, expected tenant_id={self._tenant_id}, "
f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
)
raise ValueError(err_msg)
model_id = uuid.UUID(related_model_id)

if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_ids.append(model_id)
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_ids.append(model_id)

tool_files = self._load_tool_files(tool_file_ids)
upload_files = self._load_upload_files(upload_file_ids)
for file in files:
model_id = uuid.UUID(file.related_id)
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_row = upload_files.get(model_id)
if upload_file_row is None:
raise ValueError(...)
file._storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(model_id)
if tool_file_row is None:
raise ValueError(...)
file._storage_key = tool_file_row.file_key
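
A usage sketch, assuming a SQLAlchemy engine at `db.engine` and a list of `File` objects already in hand (e.g. collected from draft variables):

with Session(db.engine) as session:
    loader = StorageKeyLoader(session, tenant_id=tenant_id)
    # Two queries total (upload files + tool files), regardless of len(files).
    loader.load_storage_keys(files)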

+ 78
- 0
api/factories/variable_factory.py Show file

@@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception):
pass


class TypeMismatchError(Exception):
pass


# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
@@ -110,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
return cast(Variable, result)


def infer_segment_type_from_value(value: Any, /) -> SegmentType:
return build_segment(value).value_type


def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()
@@ -140,10 +148,80 @@ def build_segment(value: Any, /) -> Segment:
case SegmentType.NONE:
return ArrayAnySegment(value=value)
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")


def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
"""
Build a segment with explicit type checking.

This function creates a segment from a value while enforcing type compatibility
with the specified segment_type. It provides stricter type validation compared
to the standard build_segment function.

Args:
segment_type: The expected SegmentType for the resulting segment
value: The value to be converted into a segment

Returns:
Segment: A segment instance of the appropriate type

Raises:
TypeMismatchError: If the value type doesn't match the expected segment_type

Special Cases:
- For empty list [] values, if segment_type is array[*], returns the corresponding array type
- Type validation is performed before segment creation

Examples:
>>> build_segment_with_type(SegmentType.STRING, "hello")
StringSegment(value="hello")

>>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
ArrayStringSegment(value=[])

>>> build_segment_with_type(SegmentType.STRING, 123)
# Raises TypeMismatchError
"""
# Handle None values
if value is None:
if segment_type == SegmentType.NONE:
return NoneSegment()
else:
raise TypeMismatchError(f"Expected {segment_type}, but got None")

# Handle empty list special case for array types
if isinstance(value, list) and len(value) == 0:
if segment_type == SegmentType.ARRAY_ANY:
return ArrayAnySegment(value=value)
elif segment_type == SegmentType.ARRAY_STRING:
return ArrayStringSegment(value=value)
elif segment_type == SegmentType.ARRAY_NUMBER:
return ArrayNumberSegment(value=value)
elif segment_type == SegmentType.ARRAY_OBJECT:
return ArrayObjectSegment(value=value)
elif segment_type == SegmentType.ARRAY_FILE:
return ArrayFileSegment(value=value)
else:
raise TypeMismatchError(f"Expected {segment_type}, but got empty list")

# Build segment using existing logic to infer actual type
inferred_segment = build_segment(value)
inferred_type = inferred_segment.value_type

# Type compatibility checking
if inferred_type == segment_type:
return inferred_segment

# Type mismatch - raise error with descriptive message
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but value '{value}' "
f"(type: {type(value).__name__}) corresponds to {inferred_type}"
)
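
A short sketch of the strict-typing contract described in the docstring above:

seg = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3])
assert seg.value_type == SegmentType.ARRAY_NUMBER

try:
    build_segment_with_type(SegmentType.STRING, 123)
except TypeMismatchError as e:
    print(e)  # Type mismatch: expected string, ...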


def segment_to_variable(
*,
segment: Segment,

+ 22
- 0
api/libs/datetime_utils.py Show file

@@ -0,0 +1,22 @@
import abc
import datetime
from typing import Protocol


class _NowFunction(Protocol):
@abc.abstractmethod
def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
pass


# _now_func is a callable with the _NowFunction signature.
# Its sole purpose is to abstract time retrieval, enabling
# developers to mock this behavior in tests and time-dependent scenarios.
_now_func: _NowFunction = datetime.datetime.now


def naive_utc_now() -> datetime.datetime:
"""Return a naive datetime object (without timezone information)
representing current UTC time.
"""
return _now_func(datetime.UTC).replace(tzinfo=None)
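
Because `_now_func` is plain module state, tests can swap it out. A sketch using pytest's `monkeypatch` (test layout assumed):

import datetime
from libs import datetime_utils

def test_naive_utc_now(monkeypatch):
    fixed = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
    monkeypatch.setattr(datetime_utils, "_now_func", lambda tz: fixed.astimezone(tz))
    assert datetime_utils.naive_utc_now() == datetime.datetime(2025, 1, 1)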

+ 11
- 0
api/libs/jsonutil.py Show file

@@ -0,0 +1,11 @@
import json

from pydantic import BaseModel


class PydanticModelEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, BaseModel):
return o.model_dump()
else:
    return super().default(o)
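
Usage is the standard `json.dumps` hook, e.g. `json.dumps(payload, cls=PydanticModelEncoder)` for any payload containing pydantic models.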

+ 20
- 0
api/models/_workflow_exc.py Show file

@@ -0,0 +1,20 @@
"""All these exceptions are not meant to be caught by callers."""


class WorkflowDataError(Exception):
"""Base class for all workflow data related exceptions.

This should be used to indicate issues with workflow data integrity, such as
no `graph` configuration, missing `nodes` field in `graph` configuration, or
similar issues.
"""

pass


class NodeNotFoundError(WorkflowDataError):
"""Raised when a node with the specified ID is not found in the workflow."""

def __init__(self, node_id: str):
super().__init__(f"Node with ID '{node_id}' not found in the workflow.")
self.node_id = node_id

+ 15
- 0
api/models/model.py Show file

@@ -611,6 +611,14 @@ class InstalledApp(Base):
return tenant


class ConversationSource(StrEnum):
"""This enumeration is designed for use with `Conversation.from_source`."""

# NOTE(QuantumGhost): The enumeration members may not cover all possible cases.
API = "api"
CONSOLE = "console"


class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
@@ -632,7 +640,14 @@ class Conversation(Base):
system_instruction = db.Column(db.Text)
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
status = db.Column(db.String(255), nullable=False)

# The `invoke_from` records how the conversation is created.
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = db.Column(db.String(255), nullable=True)

# ref: ConversationSource.
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)

+ 229
- 9
api/models/workflow.py Show file

@@ -7,10 +7,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import orm

from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type

from ._workflow_exc import NodeNotFoundError, WorkflowDataError

if TYPE_CHECKING:
from models.model import AppMode
@@ -72,6 +78,10 @@ class WorkflowType(Enum):
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT


class _InvalidGraphDefinitionError(Exception):
pass


class Workflow(Base):
"""
Workflow, for `Workflow App` and `Chat App workflow mode`.
@@ -136,6 +146,8 @@ class Workflow(Base):
"conversation_variables", db.Text, nullable=False, server_default="{}"
)

VERSION_DRAFT = "draft"

@classmethod
def new(
cls,
@@ -179,8 +191,72 @@ class Workflow(Base):

@property
def graph_dict(self) -> Mapping[str, Any]:
# TODO(QuantumGhost): Consider caching `graph_dict` to avoid repeated JSON decoding.
#
# Using `functools.cached_property` could help, but some code in the codebase may
# modify the returned dict, which can cause issues elsewhere.
#
# For example, changing this property to a cached property led to errors like the
# following when single stepping an `Iteration` node:
#
# Root node id 1748401971780start not found in the graph
#
# There is currently no standard way to make a dict deeply immutable in Python,
# and tracking modifications to the returned dict is difficult. For now, we leave
# the code as-is to avoid these issues.
#
# Currently, the following functions / methods would mutate the returned dict:
#
# - `_get_graph_and_variable_pool_of_single_iteration`.
# - `_get_graph_and_variable_pool_of_single_loop`.
return json.loads(self.graph) if self.graph else {}

def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
"""Extract a node configuration from the workflow graph by node ID.
A node configuration is a dictionary containing the node's properties, including
the node's id, title, and its data as a dict.
"""
workflow_graph = self.graph_dict

if not workflow_graph:
raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}")

nodes = workflow_graph.get("nodes")
if not nodes:
raise WorkflowDataError("nodes not found in workflow graph")

try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict)
return node_config

@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config_data.get("type"))
return node_type

@staticmethod
def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
in_loop = node_config.get("isInLoop", False)
in_iteration = node_config.get("isInIteration", False)
if in_loop:
loop_id = node_config.get("loop_id")
if loop_id is None:
raise _InvalidGraphDefinitionError("invalid graph")
return NodeType.LOOP, loop_id
elif in_iteration:
iteration_id = node_config.get("iteration_id")
if iteration_id is None:
raise _InvalidGraphDefinitionError("invalid graph")
return NodeType.ITERATION, iteration_id
else:
return None
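
# Usage sketch for the helpers above (node ids are illustrative):
#
#     node_config = workflow.get_node_config_by_id("llm_1")
#     node_type = Workflow.get_node_type_from_node_config(node_config)
#     enclosing = Workflow.get_enclosing_node_type_and_id(node_config)
#     if enclosing is not None:
#         container_type, container_id = enclosing  # e.g. (NodeType.ITERATION, "iter_1")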

@property
def features(self) -> str:
"""
@@ -376,6 +452,10 @@ class Workflow(Base):
ensure_ascii=False,
)

@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)


class WorkflowRun(Base):
"""
@@ -835,8 +915,18 @@ def _naive_utc_datetime():


class WorkflowDraftVariable(Base):
    """`WorkflowDraftVariable` records variables and outputs generated while
    debugging a workflow or chatflow.

IMPORTANT: This model maintains multiple invariant rules that must be preserved.
Do not instantiate this class directly with the constructor.

Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`,
`new_node_variable`) defined below to ensure all invariants are properly maintained.
"""

@staticmethod
def unique_columns() -> list[str]:
def unique_app_id_node_id_name() -> list[str]:
return [
"app_id",
"node_id",
@@ -844,7 +934,9 @@ class WorkflowDraftVariable(Base):
]

__tablename__ = "workflow_draft_variables"
__table_args__ = (UniqueConstraint(*unique_columns()),)
__table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
# Required for instance variable annotation.
__allow_unmapped__ = True

# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
@@ -925,6 +1017,36 @@ class WorkflowDraftVariable(Base):
default=None,
)

# Cache for deserialized value
#
# NOTE(QuantumGhost): This field serves two purposes:
#
# 1. Caches deserialized values to reduce repeated parsing costs
# 2. Allows modification of the deserialized value after retrieval,
#    particularly important for `File` variables, which require database
# lookups to obtain storage_key and other metadata
#
# Use double underscore prefix for better encapsulation,
# making this attribute harder to access from outside the class.
__value: Segment | None

def __init__(self, *args, **kwargs):
"""
The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state
used by the model instance.

Please use the factory methods
(`new_conversation_variable`, `new_sys_variable`, `new_node_variable`)
defined below to create instances of this class.
"""
super().__init__(*args, **kwargs)
self.__value = None

@orm.reconstructor
def _init_on_load(self):
self.__value = None

def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
if not isinstance(selector, list):
@@ -939,15 +1061,92 @@ class WorkflowDraftVariable(Base):
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)

def get_value(self) -> Segment | None:
return build_segment(json.loads(self.value))
def _loads_value(self) -> Segment:
value = json.loads(self.value)
return self.build_segment_with_type(self.value_type, value)

@staticmethod
def rebuild_file_types(value: Any) -> Any:
# NOTE(QuantumGhost): Temporary workaround for structured data handling.
# By this point, `output` has been converted to dict by
# `WorkflowEntry.handle_special_values`, so we need to
# reconstruct File objects from their serialized form
# to maintain proper variable saving behavior.
#
# Ideally, we should work with structured data objects directly
# rather than their serialized forms.
# However, multiple components in the codebase depend on
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict):
if not maybe_file_object(value):
return value
return File.model_validate(value)
elif isinstance(value, list) and value:
first = value[0]
if not maybe_file_object(first):
return value
return [File.model_validate(i) for i in value]
else:
return value

@classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
# Extends `variable_factory.build_segment_with_type` functionality by
# reconstructing `FileSegment` or `ArrayFileSegment` objects from
# their serialized dictionary or list representations, respectively.
if segment_type == SegmentType.FILE:
if isinstance(value, File):
return build_segment_with_type(segment_type, value)
elif isinstance(value, dict):
file = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type, file)
else:
raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
if segment_type == SegmentType.ARRAY_FILE:
if not isinstance(value, list):
raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
file_list = cls.rebuild_file_types(value)
return build_segment_with_type(segment_type=segment_type, value=file_list)

return build_segment_with_type(segment_type=segment_type, value=value)

def get_value(self) -> Segment:
"""Decode the serialized value into its corresponding `Segment` object.

This method caches the result, so repeated calls will return the same
object instance without re-parsing the serialized data.

If you need to modify the returned `Segment`, use `value.model_copy()`
to create a copy first to avoid affecting the cached instance.

For more information about the caching mechanism, see the documentation
of the `__value` field.

Returns:
Segment: The deserialized value as a Segment object.
"""

if self.__value is not None:
return self.__value
value = self._loads_value()
self.__value = value
return value

def set_name(self, name: str):
self.name = name
self._set_selector([self.node_id, name])

def set_value(self, value: Segment):
self.value = json.dumps(value.value)
"""Updates the `value` and corresponding `value_type` fields in the database model.

This method also stores the provided Segment object in the deserialized cache
without creating a copy, allowing for efficient value access.

Args:
value: The Segment object to store as the variable's value.
"""
self.__value = value
self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder)
self.value_type = value.value_type
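
# Sketch of the caching contract (variable names illustrative): get_value()
# returns the cached Segment instance, so copy before mutating.
#
#     seg = draft_var.get_value()
#     assert draft_var.get_value() is seg      # same cached instance
#     draft_var.set_value(seg.model_copy())    # replaces cache and serialized form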

def get_node_id(self) -> str | None:
@@ -973,6 +1172,7 @@ class WorkflowDraftVariable(Base):
node_id: str,
name: str,
value: Segment,
node_execution_id: str | None,
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
@@ -984,6 +1184,7 @@ class WorkflowDraftVariable(Base):
variable.name = name
variable.set_value(value)
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
variable.node_execution_id = node_execution_id
return variable

@classmethod
@@ -993,13 +1194,17 @@ class WorkflowDraftVariable(Base):
app_id: str,
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
description=description,
node_execution_id=None,
)
variable.editable = True
return variable

@classmethod
@@ -1009,9 +1214,16 @@ class WorkflowDraftVariable(Base):
app_id: str,
name: str,
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=name,
node_execution_id=node_execution_id,
value=value,
)
variable.editable = editable
return variable

@@ -1023,11 +1235,19 @@ class WorkflowDraftVariable(Base):
node_id: str,
name: str,
value: Segment,
node_execution_id: str,
visible: bool = True,
editable: bool = True,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
variable = cls._new(
app_id=app_id,
node_id=node_id,
name=name,
node_execution_id=node_execution_id,
value=value,
)
variable.visible = visible
variable.editable = True
variable.editable = editable
return variable

@property

+ 1
- 0
api/pyproject.toml Show file

@@ -149,6 +149,7 @@ dev = [
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
"hypothesis>=6.131.15",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",

+ 3
- 0
api/services/app_dsl_service.py Show file

@@ -32,6 +32,7 @@ from models import Account, App, AppMode
from models.model import AppModelConfig
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)
@@ -292,6 +293,8 @@ class AppDslService:
dependencies=check_dependencies_pending_data,
)

draft_var_srv = WorkflowDraftVariableService(session=self._session)
draft_var_srv.delete_workflow_variables(app_id=app.id)
return Import(
id=import_id,
status=status,

+ 4
- 0
api/services/errors/app.py Show file

@@ -4,3 +4,7 @@ class MoreLikeThisDisabledError(Exception):

class WorkflowHashNotEqualError(Exception):
pass


class IsDraftWorkflowError(Exception):
pass

+ 721
- 0
api/services/workflow_draft_variable_service.py Show file

@@ -0,0 +1,721 @@
import dataclasses
import datetime
import logging
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar

from sqlalchemy import Engine, orm, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_, or_

from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
from core.workflow.variable_loader import VariableLoader
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable

_logger = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class WorkflowDraftVariableList:
variables: list[WorkflowDraftVariable]
total: int | None = None


class WorkflowDraftVariableError(Exception):
pass


class VariableResetError(WorkflowDraftVariableError):
pass


class UpdateNotSupportedError(WorkflowDraftVariableError):
pass


class DraftVarLoader(VariableLoader):
# This implements the VariableLoader interface for loading draft variables.
#
# ref: core.workflow.variable_loader.VariableLoader

# Database engine used for loading variables.
_engine: Engine
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]

def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
self._fallback_variables = fallback_variables or []

def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])

def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
if not selectors:
return []

# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {}

with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)

for draft_var in draft_vars:
segment = draft_var.get_value()
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
selector_tuple = self._selector_to_tuple(variable.selector)
variable_by_selector[selector_tuple] = variable

# Important: File values loaded from draft variables carry no storage key yet;
# batch-load the keys below so the files are actually readable.
files: list[File] = []
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
storage_key_loader.load_storage_keys(files)

return list(variable_by_selector.values())
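
# Usage sketch (names assumed): load persisted outputs of earlier nodes and
# feed them into the variable pool before single-stepping a downstream node.
#
#     loader = DraftVarLoader(engine=db.engine, app_id=app.id, tenant_id=tenant_id)
#     for variable in loader.load_variables([["node_a", "result"], ["sys", "files"]]):
#         variable_pool.add(variable.selector, variable)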


class WorkflowDraftVariableService:
_session: Session

def __init__(self, session: Session) -> None:
self._session = session

def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()

def get_draft_variables_by_selectors(
self,
app_id: str,
selectors: Sequence[list[str]],
) -> list[WorkflowDraftVariable]:
ors = []
for selector in selectors:
node_id, name = selector
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))

# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
#
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
)
return variables

def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
total = None
query = self._session.query(WorkflowDraftVariable).filter(criteria)
if page == 1:
total = query.count()
variables = (
# Do not load the `value` field.
query.options(orm.defer(WorkflowDraftVariable.value))
.order_by(WorkflowDraftVariable.id.desc())
.limit(limit)
.offset((page - 1) * limit)
.all()
)

return WorkflowDraftVariableList(variables=variables, total=total)

def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
criteria = (
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
query = self._session.query(WorkflowDraftVariable).filter(*criteria)
variables = query.order_by(WorkflowDraftVariable.id.desc()).all()
return WorkflowDraftVariableList(variables=variables)

def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id)

def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)

def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)

def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)

def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)

def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name)

def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
variable = (
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.name == name,
)
.first()
)
return variable

def update_variable(
self,
variable: WorkflowDraftVariable,
name: str | None = None,
value: Segment | None = None,
) -> WorkflowDraftVariable:
if not variable.editable:
raise UpdateNotSupportedError(f"variable not support updating, id={variable.id}")
if name is not None:
variable.set_name(name)
if value is not None:
variable.set_value(value)
variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._session.flush()
return variable

def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
conv_var_by_name = {i.name: i for i in workflow.conversation_variables}
conv_var = conv_var_by_name.get(variable.name)

if conv_var is None:
self._session.delete(instance=variable)
self._session.flush()
_logger.warning(
"Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name
)
return None

variable.set_value(conv_var)
variable.last_edited_at = None
self._session.add(variable)
self._session.flush()
return variable

def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
# If a variable does not allow updating, it makes no sense to reset it.
if not variable.editable:
return variable
# No execution record for this variable, delete the variable instead.
if variable.node_execution_id is None:
self._session.delete(instance=variable)
self._session.flush()
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None

query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
node_exec = self._session.scalars(query).first()
if node_exec is None:
_logger.warning(
    "Node execution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
variable.id,
variable.name,
variable.node_execution_id,
)
self._session.delete(instance=variable)
self._session.flush()
return None

# Get node type for proper value extraction
node_config = workflow.get_node_config_by_id(variable.node_id)
node_type = workflow.get_node_type_from_node_config(node_config)

outputs_dict = node_exec.outputs_dict or {}

# Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
# VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
#
# This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
# and `save` methods.
if node_type == NodeType.VARIABLE_ASSIGNER:
return variable

if variable.name not in outputs_dict:
# If variable not found in execution data, delete the variable
self._session.delete(instance=variable)
self._session.flush()
return None
value = outputs_dict[variable.name]
value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value)
# Extract variable value using unified logic
variable.set_value(value_seg)
variable.last_edited_at = None # Reset to indicate this is a reset operation
self._session.flush()
return variable

def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
variable_type = variable.get_variable_type()
if variable_type == DraftVariableType.CONVERSATION:
return self._reset_conv_var(workflow, variable)
elif variable_type == DraftVariableType.NODE:
return self._reset_node_var(workflow, variable)
else:
raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")

def delete_variable(self, variable: WorkflowDraftVariable):
self._session.delete(variable)

def delete_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
.filter(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
)

def delete_node_variables(self, app_id: str, node_id: str):
return self._delete_node_variables(app_id, node_id)

def _delete_node_variables(self, app_id: str, node_id: str):
self._session.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
).delete()

def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None:
draft_var = self._get_variable(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=str(SystemVariableKey.CONVERSATION_ID),
)
if draft_var is None:
return None
segment = draft_var.get_value()
if not isinstance(segment, StringSegment):
_logger.warning(
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
app_id,
draft_var.id,
)
return None
return segment.value

def get_or_create_conversation(
self,
account_id: str,
app: App,
workflow: Workflow,
) -> str:
"""
get_or_create_conversation creates and returns the ID of a conversation for debugging.

If a conversation already exists, as determined by the following criteria, its ID is returned:
- The system variable `sys.conversation_id` exists in the draft variable table, and
- A corresponding conversation record is found in the database.

If no such conversation exists, a new conversation is created and its ID is returned.
"""
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)

if conv_id is not None:
conversation = (
self._session.query(Conversation)
.filter(
Conversation.id == conv_id,
Conversation.app_id == workflow.app_id,
)
.first()
)
# Only return the conversation ID if it exists and is valid (i.e. has a corresponding conversation record in the DB).
if conversation is not None:
return conv_id
conversation = Conversation(
app_id=workflow.app_id,
app_model_config_id=app.app_model_config_id,
model_provider=None,
model_id="",
override_model_configs=None,
mode=app.mode,
name="Draft Debugging Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.DEBUGGER.value,
from_source="console",
from_end_user_id=None,
from_account_id=account_id,
)

self._session.add(conversation)
self._session.flush()
return conversation.id

def prefill_conversation_variable_default_values(self, workflow: Workflow):
    """Seed draft variables with the default values of the workflow's
    conversation variables, leaving any existing values untouched
    (IGNORE upsert policy)."""
draft_conv_vars: list[WorkflowDraftVariable] = []
for conv_var in workflow.conversation_variables:
draft_var = WorkflowDraftVariable.new_conversation_variable(
app_id=workflow.app_id,
name=conv_var.name,
value=conv_var,
description=conv_var.description,
)
draft_conv_vars.append(draft_var)
_batch_upsert_draft_variable(
self._session,
draft_conv_vars,
policy=_UpsertPolicy.IGNORE,
)


class _UpsertPolicy(StrEnum):
IGNORE = "ignore"
OVERWRITE = "overwrite"


def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
if not draft_vars:
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
#
# 1. The variable saving process involves writing multiple rows to the
# `workflow_draft_variables` table. Batch insertion significantly improves performance.
# 2. Using the ORM would require either:
#
# a. Checking for the existence of each variable before insertion,
# resulting in 2n SQL statements for n variables and potential concurrency issues.
# b. Attempting insertion first, then updating if a unique index violation occurs,
# which still results in n to 2n SQL statements.
#
# Both approaches are inefficient and suboptimal.
# 3. We do not need to retrieve the results of the SQL execution or populate ORM
# model instances with the returned values.
# 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
# variables in a single SQL statement, avoiding the issues above.
#
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
# insert operations instead of the ORM layer.
stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
if policy == _UpsertPolicy.OVERWRITE:
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
set_={
"updated_at": stmt.excluded.updated_at,
"last_edited_at": stmt.excluded.last_edited_at,
"description": stmt.excluded.description,
"value_type": stmt.excluded.value_type,
"value": stmt.excluded.value,
"visible": stmt.excluded.visible,
"editable": stmt.excluded.editable,
"node_execution_id": stmt.excluded.node_execution_id,
},
)
elif policy == _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
else:
raise Exception("Invalid value for update policy.")
session.execute(stmt)
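
# For reference, with the OVERWRITE policy the statement above compiles to a
# single round-trip of roughly this shape (columns abridged):
#
#     INSERT INTO workflow_draft_variables (app_id, node_id, name, value, ...)
#     VALUES (...), (...)
#     ON CONFLICT (app_id, node_id, name)
#     DO UPDATE SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at, ...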


def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,
"name": model.name,
"selector": model.selector,
"value_type": model.value_type,
"value": model.value,
"node_execution_id": model.node_execution_id,
}
if model.visible is not None:
d["visible"] = model.visible
if model.editable is not None:
d["editable"] = model.editable
if model.created_at is not None:
d["created_at"] = model.created_at
if model.updated_at is not None:
d["updated_at"] = model.updated_at
if model.description is not None:
d["description"] = model.description
return d


def _build_segment_for_serialized_values(v: Any) -> Segment:
"""
Reconstructs Segment objects from serialized values, with special handling
for FileSegment and ArrayFileSegment types.

This function should only be used when:
1. No explicit type information is available
2. The input value is in serialized form (dict or list)

It detects potential file objects in the serialized data and properly rebuilds the
appropriate segment type.
"""
return build_segment(WorkflowDraftVariable.rebuild_file_types(v))


class DraftVariableSaver:
# _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
# Its sole possible value is `None`.
#
# This is used to signal the execution of a workflow node when it has no other outputs.
_DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__"
_DUMMY_OUTPUT_VALUE: ClassVar[None] = None

# _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that
# should be excluded when saving draft variables. This prevents certain internal or
# technical variables from being exposed in the draft environment, particularly those
# that aren't meant to be directly edited or viewed by users.
_EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
NodeType.LLM: frozenset(["finish_reason"]),
NodeType.LOOP: frozenset(["loop_round"]),
}

# Database session used for persisting draft variables.
_session: Session

# The application ID associated with the draft variables.
# This should match the `Workflow.app_id` of the workflow to which the current node belongs.
_app_id: str

# The ID of the node for which DraftVariableSaver is saving output variables.
_node_id: str

# The type of the current node (see NodeType).
_node_type: NodeType

# Indicates how the workflow execution was triggered (see InvokeFrom).
_invoke_from: InvokeFrom

# The ID of the WorkflowNodeExecution record that produced these outputs.
_node_execution_id: str

# _enclosing_node_id identifies the container node that the current node belongs to.
# For example, if the current node is an LLM node inside an Iteration node
# or Loop node, then `_enclosing_node_id` refers to the ID of
# the containing Iteration or Loop node.
#
# If the current node is not nested within another node, `_enclosing_node_id` is
# `None`.
_enclosing_node_id: str | None

def __init__(
self,
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
invoke_from: InvokeFrom,
node_execution_id: str,
enclosing_node_id: str | None = None,
):
self._session = session
self._app_id = app_id
self._node_id = node_id
self._node_type = node_type
self._invoke_from = invoke_from
self._node_execution_id = node_execution_id
self._enclosing_node_id = enclosing_node_id

def _create_dummy_output_variable(self):
return WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
name=self._DUMMY_OUTPUT_IDENTITY,
node_execution_id=self._node_execution_id,
value=build_segment(self._DUMMY_OUTPUT_VALUE),
visible=False,
editable=False,
)

def _should_save_output_variables_for_draft(self) -> bool:
# Only save output variables for debugging execution of workflow.
if self._invoke_from != InvokeFrom.DEBUGGER:
return False
if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER:
# Currently we do not save output variables for nodes inside loop or iteration.
return False
return True

def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
draft_vars: list[WorkflowDraftVariable] = []
updated_variables = get_updated_variables(process_data) or []

for item in updated_variables:
selector = item.selector
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
# VariableAssigner: ConversationVariable and iteration variable.
# We only save conversation variables here.
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value)
draft_vars.append(
WorkflowDraftVariable.new_conversation_variable(
app_id=self._app_id,
name=item.name,
value=segment,
)
)
# Add a dummy output variable to indicate that this node is executed.
draft_vars.append(self._create_dummy_output_variable())
return draft_vars

def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
draft_vars = []
has_non_sys_variables = False
for name, value in output.items():
value_seg = _build_segment_for_serialized_values(value)
node_id, name = self._normalize_variable_for_start_node(name)
# If node_id is not `sys`, it means that the variable is a user-defined input field
# in `Start` node.
if node_id != SYSTEM_VARIABLE_NODE_ID:
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
visible=True,
editable=True,
)
)
has_non_sys_variables = True
else:
if name == SystemVariableKey.FILES:
# Here we know the type of variable must be `array[file]`, we
# just build files from the value.
files = [File.model_validate(v) for v in value]
if files:
value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files)
else:
value_seg = ArrayFileSegment(value=[])

draft_vars.append(
WorkflowDraftVariable.new_sys_variable(
app_id=self._app_id,
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
editable=self._should_variable_be_editable(node_id, name),
)
)
if not has_non_sys_variables:
draft_vars.append(self._create_dummy_output_variable())
return draft_vars

def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]:
if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
return self._node_id, name
_, name_ = name.split(".", maxsplit=1)
return SYSTEM_VARIABLE_NODE_ID, name_
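
# e.g. "sys.files" -> ("sys", "files"); "user_name" -> (self._node_id, "user_name")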

def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
draft_vars = []
for name, value in output.items():
if not self._should_variable_be_saved(name):
_logger.debug(
"Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s",
name,
self._node_type,
)
continue
if isinstance(value, Segment):
value_seg = value
else:
value_seg = _build_segment_for_serialized_values(value)
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
visible=self._should_variable_be_visible(self._node_id, self._node_type, name),
)
)
return draft_vars

def save(
self,
process_data: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
):
draft_vars: list[WorkflowDraftVariable] = []
if outputs is None:
outputs = {}
if process_data is None:
process_data = {}
if not self._should_save_output_variables_for_draft():
return
if self._node_type == NodeType.VARIABLE_ASSIGNER:
draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data)
elif self._node_type == NodeType.START:
draft_vars = self._build_variables_from_start_mapping(outputs)
else:
draft_vars = self._build_variables_from_mapping(outputs)
_batch_upsert_draft_variable(self._session, draft_vars)

@staticmethod
def _should_variable_be_editable(node_id: str, name: str) -> bool:
if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID):
return False
if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
return False
return True

@staticmethod
def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool:
if node_type == NodeType.IF_ELSE:
return False
if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
return False
return True

def _should_variable_be_saved(self, name: str) -> bool:
exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type)
if exclude_var_names is None:
return True
return name not in exclude_var_names

+ 225
- 17
api/services/workflow_service.py Show file

@@ -1,6 +1,7 @@
import json
import time
from collections.abc import Callable, Generator, Sequence
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional
from uuid import uuid4
@@ -8,12 +9,17 @@ from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@@ -22,9 +28,11 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@@ -34,10 +42,15 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter

from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
WorkflowDraftVariableService,
)


class WorkflowService:
@@ -45,6 +58,33 @@ class WorkflowService:
Workflow Service
"""

def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
# TODO(QuantumGhost): This query is not fully covered by index.
criteria = (
WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id,
WorkflowNodeExecutionModel.app_id == app_model.id,
WorkflowNodeExecutionModel.workflow_id == workflow.id,
WorkflowNodeExecutionModel.node_id == node_id,
)
node_exec = (
db.session.query(WorkflowNodeExecutionModel)
.filter(*criteria)
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.first()
)
return node_exec

def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.count()
) > 0

def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
"""
Get draft workflow
@@ -61,6 +101,23 @@ class WorkflowService:
# return draft workflow
return workflow

def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
)
.first()
)
if not workflow:
return None
if workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
return workflow

def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
"""
Get published workflow
@@ -199,7 +256,7 @@ class WorkflowService:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
@@ -253,26 +310,85 @@ class WorkflowService:
return default_config

def run_draft_workflow_node(
self, app_model: App, node_id: str, user_inputs: dict, account: Account
self,
app_model: App,
draft_workflow: Workflow,
node_id: str,
user_inputs: Mapping[str, Any],
account: Account,
query: str = "",
files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run draft workflow node
"""
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = files or []

with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)

node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
if node_type == NodeType.START:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
conversation_id = draft_var_srv.get_or_create_conversation(
account_id=account.id,
app=app_model,
workflow=draft_workflow,
)
start_data = StartNodeData.model_validate(node_data)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
# init variable pool
variable_pool = _setup_variable_pool(
query=query,
files=files or [],
user_id=account.id,
user_inputs=user_inputs,
workflow=draft_workflow,
# NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
conversation_variables=[],
node_type=node_type,
conversation_id=conversation_id,
)

else:
variable_pool = VariablePool(
system_variables={},
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
)

variable_loader = DraftVarLoader(
engine=db.engine,
app_id=app_model.id,
tenant_id=app_model.tenant_id,
)

enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
if enclosing_node_type_and_id:
    _, enclosing_node_id = enclosing_node_type_and_id
else:
enclosing_node_id = None

run = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
variable_pool=variable_pool,
variable_loader=variable_loader,
)

# run draft workflow node
start_at = time.perf_counter()

node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
invoke_node_fn=lambda: run,
start_at=start_at,
node_id=node_id,
)
@@ -292,6 +408,18 @@ class WorkflowService:
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution = repository.to_db_model(node_execution)

with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
node_id=workflow_node_execution.node_id,
node_type=NodeType(workflow_node_execution.node_type),
invoke_from=InvokeFrom.DEBUGGER,
enclosing_node_id=enclosing_node_id,
node_execution_id=node_execution.id,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
session.commit()
return workflow_node_execution
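
# Not part of the diff: a hedged sketch of the round trip enabled above.
# DraftVariableSaver (just committed) persists this node's outputs, and a
# later single-step debug run can rehydrate them through DraftVarLoader:
#
#   loader = DraftVarLoader(engine=db.engine, app_id=app_model.id, tenant_id=app_model.tenant_id)
#   variables = loader.load_variables([["previous_node_id", "output"]])
#
# The ["node_id", "variable_name"] selector shape matches the integration
# tests further down; "previous_node_id" and "output" are placeholders.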

def run_free_workflow_node(
@@ -332,7 +460,7 @@ class WorkflowService:
node_run_result = event.run_result

# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
# node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break

if not node_run_result:
@@ -394,7 +522,7 @@ class WorkflowService:
if node_run_result.process_data
else None
)
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
outputs = node_run_result.outputs

node_execution.inputs = inputs
node_execution.process_data = process_data
@@ -531,3 +659,83 @@ class WorkflowService:

session.delete(workflow)
return True


def _setup_variable_pool(
query: str,
files: Sequence[File],
user_id: str,
user_inputs: Mapping[str, Any],
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
) -> VariablePool:
# Only inject system variables for START node type.
if node_type == NodeType.START:
# Create a variable pool.
system_inputs: dict[SystemVariableKey, Any] = {
# From inputs:
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
# From workflow model
SystemVariableKey.APP_ID: workflow.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
# Randomly generated.
SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()),
}

# Only add chatflow-specific variables for non-workflow types
if workflow.type != WorkflowType.WORKFLOW.value:
system_inputs.update(
{
SystemVariableKey.QUERY: query,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.DIALOGUE_COUNT: 0,
}
)
else:
system_inputs = {}

# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)

return variable_pool
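
# Hedged usage sketch (not part of the diff); `draft_workflow`, `account`,
# and `conversation_id` stand in for the caller's values above:
#
#   pool = _setup_variable_pool(
#       query="hello",
#       files=[],
#       user_id=account.id,
#       user_inputs={},
#       workflow=draft_workflow,
#       node_type=NodeType.START,
#       conversation_id=conversation_id,
#       conversation_variables=[],
#   )
#
# For chatflows this also seeds sys.query, sys.conversation_id, and
# sys.dialogue_count; for WORKFLOW-type apps the system inputs stay empty.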


def _rebuild_file_for_user_inputs_in_start_node(
tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any]
) -> Mapping[str, Any]:
inputs_copy = dict(user_inputs)

for variable in start_node_data.variables:
if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST):
continue
if variable.variable not in user_inputs:
continue
value = user_inputs[variable.variable]
file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
inputs_copy[variable.variable] = file
return inputs_copy
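
# Illustrative transformation (not part of the diff); the inner mapping keys
# are hypothetical placeholders for whatever schema build_from_mapping expects:
#
#   user_inputs = {"my_doc": {"transfer_method": "local_file", "upload_file_id": "<id>"}}
#   rebuilt = _rebuild_file_for_user_inputs_in_start_node(
#       tenant_id=tenant_id, start_node_data=start_data, user_inputs=user_inputs
#   )
#   # rebuilt["my_doc"] is now a File; FILE_LIST variables yield a list of File.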


def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
if variable_entity_type == VariableEntityType.FILE:
if not isinstance(value, dict):
raise ValueError(f"expected dict for file object, got {type(value)}")
return build_from_mapping(mapping=value, tenant_id=tenant_id)
elif variable_entity_type == VariableEntityType.FILE_LIST:
if not isinstance(value, list):
raise ValueError(f"expected list for file list object, got {type(value)}")
if len(value) == 0:
return []
if not isinstance(value[0], dict):
raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
return build_from_mappings(mappings=value, tenant_id=tenant_id)
else:
raise Exception("unreachable")

+ 209
- 99
api/tests/integration_tests/.env.example Show file

@@ -1,107 +1,217 @@
# OpenAI API Key
OPENAI_API_KEY=
FLASK_APP=app.py
FLASK_DEBUG=0
SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI'

CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000

# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001

# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000

# Files URL
FILES_URL=http://127.0.0.1:5001

# Time in seconds after which a file access signature is rejected
FILES_ACCESS_TIMEOUT=300

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0

# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify

# Storage configuration
# use for store upload files, private keys...
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal

# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage

# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
VECTOR_STORE=weaviate
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100


# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=4096
CODE_GENERATION_MAX_TOKENS=1024

# Mail configuration, support: resend, smtp
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@example.com>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.example.com
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false

# Sentry configuration
SENTRY_DSN=

# DEBUG
DEBUG=false
SQLALCHEMY_ECHO=false

# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=your-client-secret
NOTION_CLIENT_ID=your-client-id
NOTION_INTERNAL_SECRET=your-internal-secret

ETL_TYPE=dify
UNSTRUCTURED_API_URL=
UNSTRUCTURED_API_KEY=
SCARF_NO_ANALYTICS=false

# SSRF proxy configuration
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
SSRF_DEFAULT_MAX_RETRIES=3
SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5

BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database

# Workflow file upload limit
WORKFLOW_FILE_UPLOAD_LIMIT=10

# Azure OpenAI API Base Endpoint & API Key
AZURE_OPENAI_API_BASE=
AZURE_OPENAI_API_KEY=

# Anthropic API Key
ANTHROPIC_API_KEY=

# Replicate API Key
REPLICATE_API_KEY=

# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=

# Minimax Credentials
MINIMAX_API_KEY=
MINIMAX_GROUP_ID=

# Spark Credentials
SPARK_APP_ID=
SPARK_API_KEY=
SPARK_API_SECRET=

# Tongyi Credentials
TONGYI_DASHSCOPE_API_KEY=

# Wenxin Credentials
WENXIN_API_KEY=
WENXIN_SECRET_KEY=

# ZhipuAI Credentials
ZHIPUAI_API_KEY=

# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=

# ChatGLM Credentials
CHATGLM_API_BASE=

# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_GENERATION_MODEL_UID=
XINFERENCE_CHAT_MODEL_UID=
XINFERENCE_EMBEDDINGS_MODEL_UID=
XINFERENCE_RERANK_MODEL_UID=

# OpenLLM Credentials
OPENLLM_SERVER_URL=

# LocalAI Credentials
LOCALAI_SERVER_URL=

# Cohere Credentials
COHERE_API_KEY=

# Jina Credentials
JINA_API_KEY=

# Ollama Credentials
OLLAMA_BASE_URL=
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
# API Tool configuration
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# HTTP Node configuration
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
# Log file path
LOG_FILE=
# Log file max size, the unit is MB
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Log dateformat
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Workflow runtime configuration
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

# Together API Key
TOGETHER_API_KEY=
# Plugin configuration
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1

# Mock Switch
MOCK_SWITCH=false
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai

# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
CODE_EXECUTION_API_KEY=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

# Volcengine MaaS Credentials
VOLC_API_KEY=
VOLC_SECRET_KEY=
VOLC_MODEL_ENDPOINT_ID=
VOLC_EMBEDDING_ENDPOINT_ID=
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

# 360 AI Credentials
ZHINAO_API_KEY=
CREATE_TIDB_SERVICE_JOB_ENABLED=false

# Plugin configuration
PLUGIN_DAEMON_KEY=
PLUGIN_DAEMON_URL=
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400

# Marketplace configuration
MARKETPLACE_API_URL=
# VESSL AI Credentials
VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=
VESSL_AI_ENDPOINT_URL=

# GPUStack Credentials
GPUSTACK_SERVER_URL=
GPUSTACK_API_KEY=

# Gitee AI Credentials
GITEE_AI_API_KEY=

# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=
HTTP_PROXY='http://127.0.0.1:1092'
HTTPS_PROXY='http://127.0.0.1:1092'
NO_PROXY='localhost,127.0.0.1'
LOG_LEVEL=INFO

+ 80
- 8
api/tests/integration_tests/conftest.py Show file

@@ -1,19 +1,91 @@
import os
import pathlib
import random
import secrets
from collections.abc import Generator

# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session

# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
from app_factory import create_app
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
from services.account_service import AccountService, RegisterService


# Loading the env files if they exist
def _load_env() -> None:
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
if os.path.exists(dotenv_path):
current_file_path = pathlib.Path(__file__).absolute()
# Items later in the list have higher precedence.
files_to_load = [".env", "vdb.env"]

env_file_paths = [current_file_path.parent / i for i in files_to_load]
for path in env_file_paths:
if not path.exists():
continue

from dotenv import load_dotenv

load_dotenv(dotenv_path)
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
load_dotenv(str(path), override=True)
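
# Not part of the diff: because each file is loaded with override=True in
# list order, a key defined in both files ends up with the vdb.env value.
# Illustrative only:
#
#   .env:    REDIS_HOST=localhost
#   vdb.env: REDIS_HOST=vdb-redis   ->  os.environ["REDIS_HOST"] == "vdb-redis"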


_load_env()

_CACHED_APP = create_app()


@pytest.fixture
def flask_app() -> Flask:
return _CACHED_APP


@pytest.fixture(scope="session")
def setup_account(request) -> Generator[Account, None, None]:
"""`setup_account` completes the setup process for the Dify application.

It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.

Most tests in the `controllers` package require that Dify has been successfully set up.
"""
with _CACHED_APP.test_request_context():
rand_suffix = random.randint(int(1e6), int(1e7)) # noqa
name = f"test-user-{rand_suffix}"
email = f"{name}@example.com"
RegisterService.setup(
email=email,
name=name,
password=secrets.token_hex(16),
ip_address="localhost",
)

with _CACHED_APP.test_request_context():
with Session(bind=db.engine, expire_on_commit=False) as session:
account = session.query(Account).filter_by(email=email).one()

yield account

with _CACHED_APP.test_request_context():
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.commit()


@pytest.fixture
def flask_req_ctx():
with _CACHED_APP.test_request_context():
yield


@pytest.fixture
def auth_header(setup_account) -> dict[str, str]:
token = AccountService.get_account_jwt_token(setup_account)
return {"Authorization": f"Bearer {token}"}


@pytest.fixture
def test_client() -> Generator[FlaskClient, None, None]:
with _CACHED_APP.test_client() as client:
yield client
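
# Hedged usage sketch: the fixtures above compose like this (the endpoint
# path is a placeholder, not an assertion about real routes):
#
#   def test_requires_auth(test_client, auth_header):
#       resp = test_client.get("/console/api/some-endpoint", headers=auth_header)
#       assert resp.status_code == 200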

+ 0
- 25
api/tests/integration_tests/controllers/app_fixture.py Show file

@@ -1,25 +0,0 @@
import pytest

from app_factory import create_app
from configs import dify_config

mock_user = type(
"MockUser",
(object,),
{
"is_authenticated": True,
"id": "123",
"is_editor": True,
"is_dataset_editor": True,
"status": "active",
"get_id": "123",
"current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b",
},
)


@pytest.fixture
def app():
app = create_app()
dify_config.LOGIN_DISABLED = True
return app

+ 0
- 0
api/tests/integration_tests/controllers/console/__init__.py Show file


+ 0
- 0
api/tests/integration_tests/controllers/console/app/__init__.py Show file


+ 47
- 0
api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py Show file

@@ -0,0 +1,47 @@
import uuid
from unittest import mock

from controllers.console.app import workflow_draft_variable as draft_variable_api
from controllers.console.app import wraps
from factories.variable_factory import build_segment
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService


def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
return mock.create_autospec(WorkflowDraftVariableService)


class TestWorkflowDraftNodeVariableListApi:
def test_get(self, test_client, auth_header, monkeypatch):
srv_class = _get_mock_srv_class()
mock_app_model: App = App()
mock_app_model.id = str(uuid.uuid4())
test_node_id = "test_node_id"
mock_app_model.mode = AppMode.ADVANCED_CHAT
mock_load_app_model = mock.Mock(return_value=mock_app_model)

monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)

var1 = WorkflowDraftVariable.new_node_variable(
app_id="test_app_1",
node_id="test_node_1",
name="str_var",
value=build_segment("str_value"),
node_execution_id=str(uuid.uuid4()),
)
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
srv_class.return_value = srv_instance
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])

response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
headers=auth_header,
)
assert response.status_code == 200
response_dict = response.json
assert isinstance(response_dict, dict)
assert "items" in response_dict
assert len(response_dict["items"]) == 1

+ 0
- 9
api/tests/integration_tests/controllers/test_controllers.py Show file

@@ -1,9 +0,0 @@
from unittest.mock import patch

from app_fixture import mock_user # type: ignore


def test_post_requires_login(app):
with app.test_client() as client, patch("flask_login.utils._get_user", mock_user):
response = client.get("/console/api/data-source/integrates")
assert response.status_code == 200

+ 0
- 0
api/tests/integration_tests/factories/__init__.py Show file


+ 371
- 0
api/tests/integration_tests/factories/test_storage_key_loader.py Show file

@@ -0,0 +1,371 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session

from core.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole


@pytest.mark.usefixtures("flask_req_ctx")
class TestStorageKeyLoader(unittest.TestCase):
"""
Integration tests for StorageKeyLoader class.

Tests the batched loading of storage keys from the database for files
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
"""

def setUp(self):
"""Set up test data before each test method."""
self.session = db.session()
self.tenant_id = str(uuid4())
self.user_id = str(uuid4())
self.conversation_id = str(uuid4())

# Create test data that will be cleaned up after each test
self.test_upload_files = []
self.test_tool_files = []

# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)

def tearDown(self):
"""Clean up test data after each test method."""
self.session.rollback()

def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if storage_key is None:
storage_key = f"test_storage_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id

upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=storage_key,
name="test_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file.id = file_id

self.session.add(upload_file)
self.session.flush()
self.test_upload_files.append(upload_file)

return upload_file

def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if file_key is None:
file_key = f"test_file_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id

tool_file = ToolFile()
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048

self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)

return tool_file

def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id

# Set related_id for all transfer methods; REMOTE_URL additionally gets a remote_url
file_related_id = None
remote_url = None

if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id

return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)

def test_load_storage_keys_local_file(self):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)

# Load storage keys
self.loader.load_storage_keys([file])

# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key

def test_load_storage_keys_remote_url(self):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)

# Load storage keys
self.loader.load_storage_keys([file])

# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key

def test_load_storage_keys_tool_file(self):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)

# Load storage keys
self.loader.load_storage_keys([file])

# Verify storage key was loaded correctly
assert file._storage_key == tool_file.file_key

def test_load_storage_keys_mixed_methods(self):
"""Test batch loading with mixed transfer methods."""
# Create test data for different transfer methods
upload_file1 = self._create_upload_file()
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()

file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)

files = [file1, file2, file3]

# Load storage keys
self.loader.load_storage_keys(files)

# Verify all storage keys were loaded correctly
assert file1._storage_key == upload_file1.key
assert file2._storage_key == upload_file2.key
assert file3._storage_key == tool_file.file_key

def test_load_storage_keys_empty_list(self):
"""Test with empty file list."""
# Should not raise any exceptions
self.loader.load_storage_keys([])

def test_load_storage_keys_tenant_mismatch(self):
"""Test tenant_id validation."""
# Create file with different tenant_id
upload_file = self._create_upload_file()
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)

# Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])

assert "invalid file, expected tenant_id" in str(context.value)

def test_load_storage_keys_missing_file_id(self):
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None

# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])

assert str(context.value) == "file id should not be None."

def test_load_storage_keys_nonexistent_upload_file_records(self):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)

# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])

def test_load_storage_keys_nonexistent_tool_file_records(self):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)

# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])

def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"

# Should raise ValueError for invalid UUID
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])

def test_load_storage_keys_batch_efficiency(self):
"""Test batched operations use efficient queries."""
# Create multiple files of different types
upload_files = [self._create_upload_file() for _ in range(3)]
tool_files = [self._create_tool_file() for _ in range(2)]

files = []
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)

# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
self.loader.load_storage_keys(files)

# Should make exactly 2 queries (one for upload_files, one for tool_files)
assert mock_scalars.call_count == 2

# Verify all storage keys were loaded correctly
for i, file in enumerate(files[:3]):
assert file._storage_key == upload_files[i].key
for i, file in enumerate(files[3:]):
assert file._storage_key == tool_files[i].file_key

def test_load_storage_keys_tenant_isolation(self):
"""Test that tenant isolation works correctly."""
# Create files for different tenants
other_tenant_id = str(uuid4())

# Create upload file for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)

# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
key="other_tenant_key",
name="other_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file_other.id = str(uuid4())
self.session.add(upload_file_other)
self.session.flush()

# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)

# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_other])

assert "invalid file, expected tenant_id" in str(context.value)

# Current tenant's file should still work
self.loader.load_storage_keys([file_current])
assert file_current._storage_key == upload_file_current.key

def test_load_storage_keys_mixed_tenant_batch(self):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)

# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)

# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_current, file_other])

assert "invalid file, expected tenant_id" in str(context.value)

def test_load_storage_keys_duplicate_file_ids(self):
"""Test handling of duplicate file IDs in the batch."""
# Create upload file
upload_file = self._create_upload_file()

# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)

# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])

# Both files should have the same storage key
assert file1._storage_key == upload_file.key
assert file2._storage_key == upload_file.key

def test_load_storage_keys_session_isolation(self):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)

# Create loader with different session (same underlying connection)

with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])

+ 0
- 0
api/tests/integration_tests/services/__init__.py Show file


+ 501
- 0
api/tests/integration_tests/services/test_workflow_draft_variable_service.py Show file

@@ -0,0 +1,501 @@
import json
import unittest
import uuid

import pytest
from sqlalchemy.orm import Session

from core.variables.variables import StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from factories.variable_factory import build_segment
from libs import datetime_utils
from models import db
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService


@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_node1_id = "test_node_1"
_node2_id = "test_node_2"
_node_exec_id = str(uuid.uuid4())

def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._session: Session = db.session()
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
visible=False,
node_execution_id=self._node_exec_id,
),
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
),
]
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
)
_variables = list(node2_vars)
_variables.extend(
[
node1_var,
sys_var,
conv_var,
]
)

db.session.add_all(_variables)
db.session.flush()
self._variable_ids = [v.id for v in _variables]
self._node1_str_var_id = node1_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id
self._node2_var_ids = [v.id for v in node2_vars]

def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)

def tearDown(self):
self._session.rollback()

def test_list_variables(self):
srv = self._get_test_srv()
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
assert var_list.total == 5
assert len(var_list.variables) == 2
page1_var_ids = {v.id for v in var_list.variables}
assert page1_var_ids.issubset(self._variable_ids)

var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
assert var_list_2.total is None
assert len(var_list_2.variables) == 2
page2_var_ids = {v.id for v in var_list_2.variables}
assert page2_var_ids.isdisjoint(page1_var_ids)
assert page2_var_ids.issubset(self._variable_ids)

def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
assert node_var is not None
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
assert node_var.get_value() == build_segment("str_value")

def test_get_system_variable(self):
srv = self._get_test_srv()
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
assert sys_var is not None
assert sys_var.id == self._sys_var_id
assert sys_var.name == "sys_var"
assert sys_var.get_value() == build_segment("sys_value")

def test_get_conversation_variable(self):
srv = self._get_test_srv()
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
assert conv_var is not None
assert conv_var.id == self._conv_var_id
assert conv_var.name == "conv_var"
assert conv_var.get_value() == build_segment("conv_value")

def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id,
)
.count()
)
assert node2_var_count == 0

def test_delete_variable(self):
srv = self._get_test_srv()
node_1_var = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
)
srv.delete_variable(node_1_var)
exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
)
assert exists is False

def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
assert len(node_vars.variables) == 2
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)

def test_get_draft_variables_by_selectors(self):
srv = self._get_test_srv()
selectors = [
[self._node1_id, "str_var"],
[self._node2_id, "str_var"],
[self._node2_id, "int_var"],
]
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
assert len(variables) == 3
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)


@pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase):
_test_app_id: str
_test_tenant_id: str

_node1_id = "test_loader_node_1"
_node_exec_id = str(uuid.uuid4())

def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
)
_variables = [
node_var,
sys_var,
conv_var,
]

with Session(bind=db.engine, expire_on_commit=False) as session:
session.add_all(_variables)
session.flush()
session.commit()
self._variable_ids = [v.id for v in _variables]
self._node_var_id = node_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id

def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.commit()

def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables([])
assert len(variables) == 0

def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
[self._node1_id, "str_var"],
]
)
assert len(variables) == 3
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
assert conv_var.id == self._conv_var_id
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
assert sys_var.id == self._sys_var_id
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
assert node1_var.id == self._node_var_id


@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
"""Integration tests for reset_variable functionality using real database"""

_test_app_id: str
_test_tenant_id: str
_test_workflow_id: str
_session: Session
_node_id = "test_reset_node"
_node_exec_id: str
_workflow_node_exec_id: str

def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
self._test_workflow_id = str(uuid.uuid4())
self._node_exec_id = str(uuid.uuid4())
self._workflow_node_exec_id = str(uuid.uuid4())
self._session: Session = db.session()

# Create a workflow node execution record with outputs
# Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable
self._workflow_node_execution = WorkflowNodeExecutionModel(
id=self._node_exec_id, # This should match the node_execution_id in the variable
tenant_id=self._test_tenant_id,
app_id=self._test_app_id,
workflow_id=self._test_workflow_id,
triggered_from="workflow-run",
workflow_run_id=str(uuid.uuid4()),
index=1,
node_execution_id=self._node_exec_id,
node_id=self._node_id,
node_type=NodeType.LLM.value,
title="Test Node",
inputs='{"input": "test input"}',
process_data='{"test_var": "process_value", "other_var": "other_process"}',
outputs='{"test_var": "output_value", "other_var": "other_output"}',
status="succeeded",
elapsed_time=1.5,
created_by_role="account",
created_by=str(uuid.uuid4()),
)

# Create conversation variables for the workflow
self._conv_variables = [
StringVariable(
id=str(uuid.uuid4()),
name="conv_var_1",
description="Test conversation variable 1",
value="default_value_1",
),
StringVariable(
id=str(uuid.uuid4()),
name="conv_var_2",
description="Test conversation variable 2",
value="default_value_2",
),
]

# Create test variables
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="test_var",
value=build_segment("old_value"),
node_execution_id=self._node_exec_id,
)
self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now()

self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="no_exec_var",
value=build_segment("some_value"),
node_execution_id="temp_exec_id",
)
# Manually set node_execution_id to None after creation
self._node_var_without_exec.node_execution_id = None

self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="missing_exec_var",
value=build_segment("some_value"),
node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database
)

self._conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var_1",
value=build_segment("old_conv_value"),
)
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()

# Add all to database
db.session.add_all(
[
self._workflow_node_execution,
self._node_var_with_exec,
self._node_var_without_exec,
self._node_var_missing_exec,
self._conv_var,
]
)
db.session.flush()

# Store IDs for assertions
self._node_var_with_exec_id = self._node_var_with_exec.id
self._node_var_without_exec_id = self._node_var_without_exec.id
self._node_var_missing_exec_id = self._node_var_missing_exec.id
self._conv_var_id = self._conv_var.id

def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)

def _create_mock_workflow(self) -> Workflow:
"""Create a real workflow with conversation variables and graph"""
conversation_vars = self._conv_variables

# Create a simple graph with the test node
graph = {
"nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}],
"edges": [],
}

workflow = Workflow.new(
tenant_id=str(uuid.uuid4()),
app_id=self._test_app_id,
type="workflow",
version="1.0",
graph=json.dumps(graph),
features="{}",
created_by=str(uuid.uuid4()),
environment_variables=[],
conversation_variables=conversation_vars,
)
return workflow

def tearDown(self):
self._session.rollback()

def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()

# Get the variable before reset
variable = srv.get_variable(self._node_var_with_exec_id)
assert variable is not None
assert variable.get_value().value == "old_value"
assert variable.last_edited_at is not None

# Reset the variable
result = srv.reset_variable(mock_workflow, variable)

# Should return the updated variable
assert result is not None
assert result.id == self._node_var_with_exec_id
assert result.node_execution_id == self._workflow_node_execution.id
assert result.last_edited_at is None # Should be reset to None

# The returned variable should have the updated value from execution record
assert result.get_value().value == "output_value"

# Verify the variable was updated in database
updated_variable = srv.get_variable(self._node_var_with_exec_id)
assert updated_variable is not None
# The value should be updated from the execution record's outputs
assert updated_variable.get_value().value == "output_value"
assert updated_variable.last_edited_at is None
assert updated_variable.node_execution_id == self._workflow_node_execution.id

def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()

# Get the variable before reset
variable = srv.get_variable(self._node_var_without_exec_id)
assert variable is not None

# Reset the variable
result = srv.reset_variable(mock_workflow, variable)

# Should return None (variable deleted)
assert result is None

# Verify the variable was deleted
deleted_variable = srv.get_variable(self._node_var_without_exec_id)
assert deleted_variable is None

def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()

# Get the variable before reset
variable = srv.get_variable(self._node_var_missing_exec_id)
assert variable is not None

# Reset the variable
result = srv.reset_variable(mock_workflow, variable)

# Should return None (variable deleted)
assert result is None

# Verify the variable was deleted
deleted_variable = srv.get_variable(self._node_var_missing_exec_id)
assert deleted_variable is None

def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()

# Get the variable before reset
variable = srv.get_variable(self._conv_var_id)
assert variable is not None
assert variable.get_value().value == "old_conv_value"
assert variable.last_edited_at is not None

# Reset the variable
result = srv.reset_variable(mock_workflow, variable)

# Should return the updated variable
assert result is not None
assert result.id == self._conv_var_id
assert result.last_edited_at is None # Should be reset to None

# Verify the variable was updated with default value from workflow
updated_variable = srv.get_variable(self._conv_var_id)
assert updated_variable is not None
# The value should be updated from the workflow's conversation variable default
assert updated_variable.get_value().value == "default_value_1"
assert updated_variable.last_edited_at is None

def test_reset_system_variable_raises_error(self):
"""Test that resetting a system variable raises an error"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()

# Create a system variable
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
db.session.add(sys_var)
db.session.flush()

# Attempt to reset the system variable
with pytest.raises(VariableResetError) as exc_info:
srv.reset_variable(mock_workflow, sys_var)

assert "cannot reset system variable" in str(exc_info.value)
assert sys_var.id in str(exc_info.value)

+ 179
- 198
api/tests/integration_tests/workflow/nodes/test_llm.py Show file

@@ -8,8 +8,6 @@ from unittest.mock import MagicMock, patch

import pytest

from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
@@ -30,21 +28,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


@pytest.fixture(scope="session")
def app():
# Set up storage configuration
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"

# Ensure storage directory exists
os.makedirs("storage", exist_ok=True)

app = create_app()
dify_config.LOGIN_DISABLED = True
return app


def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
@@ -102,197 +85,195 @@ def init_llm_node(config: dict) -> LLMNode:
return node


def test_execute_llm(app):
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
def test_execute_llm(flask_req_ctx):
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
)

credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}

# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)

mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance

with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)

for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
},
)

credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}

# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)

mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance

with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)

for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
"""
Test execute LLM node with jinja2
"""
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
)

# Mock db.session.close()
db.session.close = MagicMock()

# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)

mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance

with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()

for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)


def test_extract_json():

+ 302
- 0
api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py Show file

@@ -0,0 +1,302 @@
import datetime
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple

from flask_restful import marshal

from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList

_TEST_APP_ID = "test_app_id"
_TEST_NODE_EXEC_ID = str(uuid.uuid4())


class TestWorkflowDraftVariableFields:
def test_conversation_variable(self):
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)

conv_var.id = str(uuid.uuid4())
conv_var.visible = True

expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(conv_var.id),
"type": conv_var.get_variable_type().value,
"name": "conv_var",
"description": "",
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
"value_type": "number",
"edited": False,
"visible": True,
}
)

assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = 1
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value

def test_create_sys_variable(self):
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=_TEST_APP_ID,
name="sys_var",
value=build_segment("a"),
editable=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)

sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
sys_var.visible = True

expected_without_value = OrderedDict(
{
"id": str(sys_var.id),
"type": sys_var.get_variable_type().value,
"name": "sys_var",
"description": "",
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
"value_type": "string",
"edited": True,
"visible": True,
}
)
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = "a"
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value

def test_node_variable(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
node_execution_id=_TEST_NODE_EXEC_ID,
)

node_var.id = str(uuid.uuid4())
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
}
)

assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
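# --- Illustrative sketch (not part of this test file) ---
# The assertions above lean on flask_restful's marshal(): an ordered
# field map drives which attributes get serialized, and the "without
# value" variant of _WORKFLOW_DRAFT_VARIABLE_FIELDS is presumably the
# same map minus the "value" entry, keeping both serializations in
# sync. A self-contained toy version of that pattern, using a
# hypothetical field map rather than the real one:
from collections import OrderedDict as _OD

from flask_restful import fields as _fields


class _ToyVar:
    id = "v1"
    name = "conv_var"
    visible = True
    value = 1


_TOY_WITHOUT_VALUE = _OD(id=_fields.String, name=_fields.String, visible=_fields.Boolean)
_TOY_WITH_VALUE = _OD(_TOY_WITHOUT_VALUE, value=_fields.Raw)

assert marshal(_ToyVar(), _TOY_WITHOUT_VALUE) == {"id": "v1", "name": "conv_var", "visible": True}
assert marshal(_ToyVar(), _TOY_WITH_VALUE)["value"] == 1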


class TestWorkflowDraftVariableList:
def test_workflow_draft_variable_list(self):
class TestCase(NamedTuple):
name: str
var_list: WorkflowDraftVariableList
expected: dict

node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="test_var",
value=build_segment("a"),
visible=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var_dict = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "test_var",
"description": "",
"selector": ["test_node", "test_var"],
"value_type": "string",
"edited": False,
"visible": True,
}
)

cases = [
TestCase(
name="empty variable list",
var_list=WorkflowDraftVariableList(variables=[]),
expected=OrderedDict(
{
"items": [],
"total": None,
}
),
),
TestCase(
name="empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[], total=10),
expected=OrderedDict(
{
"items": [],
"total": 10,
}
),
),
TestCase(
name="non-empty variable list",
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": None,
}
),
),
TestCase(
name="non-empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": 10,
}
),
),
]

for idx, case in enumerate(cases, 1):
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
f"Test case {idx} failed, {case.name=}"
)


def test_workflow_node_variables_fields():
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "conv_var"
assert item_dict["value"] == 1


def test_workflow_file_variable_with_signed_url():
"""Test that File type variables include signed URLs in API responses."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File

# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)

# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)

# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)

# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"

# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)

# Verify the File fields are preserved
assert value["id"] == test_file.id
assert value["filename"] == test_file.filename
assert value["type"] == test_file.type.value
assert value["transfer_method"] == test_file.transfer_method.value
assert value["size"] == test_file.size

# Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method)
remote_url = value["remote_url"]
assert remote_url is not None

assert isinstance(remote_url, str)
# For LOCAL_FILE, the URL should contain signature parameters
assert "timestamp=" in remote_url
assert "nonce=" in remote_url
assert "sign=" in remote_url


def test_workflow_file_variable_remote_url():
"""Test that File type variables with REMOTE_URL transfer method return the remote URL."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File

# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)

# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)

# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)

# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"

# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
remote_url = value["remote_url"]

# For REMOTE_URL, the URL should be the original remote URL
assert remote_url == test_file.remote_url

+ 0
- 165
api/tests/unit_tests/core/app/segments/test_factory.py Show file

@@ -1,165 +0,0 @@
from uuid import uuid4

import pytest

from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import ArrayAnySegment
from factories import variable_factory


def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)


def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)


def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)


def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)


def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)


def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""


def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None


def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)


def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)


def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)


def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)


def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)


def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]

+ 25
- 0
api/tests/unit_tests/core/file/test_models.py Show file

@@ -0,0 +1,25 @@
from core.file import File, FileTransferMethod, FileType


def test_file():
file = File(
id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
assert file.tenant_id == "test-tenant-id"
assert file.type == FileType.IMAGE
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.related_id == "test-related-id"
assert file.filename == "image.png"
assert file.extension == ".png"
assert file.mime_type == "image/png"
assert file.size == 67

api/tests/unit_tests/core/app/segments/test_segment.py → api/tests/unit_tests/core/variables/test_segment.py Show file


api/tests/unit_tests/core/app/segments/test_variables.py → api/tests/unit_tests/core/variables/test_variables.py Show file


+ 7
- 6
api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py Show file

@@ -3,6 +3,7 @@ import uuid
from unittest.mock import patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -197,7 +198,7 @@ def test_run():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}

assert count == 20

@@ -413,7 +414,7 @@ def test_run_parallel():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}

assert count == 32

@@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode():
parallel_arr.append(item)
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32

for item in sequential_result:
@@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 64


@@ -846,7 +847,7 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": [None, None]}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}

assert count == 14
# execute remove abnormal output
@@ -857,5 +858,5 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": []}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])}
assert count == 14
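# --- Illustrative note (not part of this diff) ---
# The updated assertions reflect that node outputs are now Segment
# objects instead of raw Python lists: equality is checked on the
# segment itself, and the raw list lives in .value. Reusing the
# constructors from the assertions above (repo imports assumed on
# the path):
from core.variables.segments import ArrayStringSegment

_seg = ArrayStringSegment(value=["dify 123", "dify 123"])
assert _seg.value == ["dify 123", "dify 123"]
assert _seg == ArrayStringSegment(value=["dify 123", "dify 123"])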

+ 16
- 3
api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py Show file

@@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P

from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"text/plain",
b"Hello, world!",
["Hello, world!"],
FileTransferMethod.LOCAL_FILE,
".txt",
),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
@@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
(
"text/plain",
b"Remote content",
["Remote content"],
FileTransferMethod.REMOTE_URL,
None,
),
],
)
def test_run_extract_text(
@@ -131,7 +144,7 @@ def test_run_extract_text(
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.outputs is not None
assert result.outputs["text"] == expected_text
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)

if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")

+ 1
- 1
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py Show file

@@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node):
},
]
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
for expected_file, result_file in zip(expected_files, result.outputs["result"].value):
assert expected_file["filename"] == result_file.filename
assert expected_file["type"] == result_file.type
assert expected_file["tenant_id"] == result_file.tenant_id

+ 54
- 12
api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py Show file

@@ -5,6 +5,7 @@ from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@@ -63,10 +64,11 @@ def test_overwrite_string_variable():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())

# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -77,6 +79,9 @@ def test_overwrite_string_variable():
input_variable,
)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -91,11 +96,20 @@ def test_overwrite_string_variable():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)

with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -148,9 +162,10 @@ def test_append_variable_to_array():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())

variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -160,6 +175,9 @@ def test_append_variable_to_array():
input_variable,
)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -174,11 +192,22 @@ def test_append_variable_to_array():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)

with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -225,13 +254,17 @@ def test_clear_array():
value=["the first value"],
)

conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -246,11 +279,20 @@ def test_clear_array():
"input_variable_selector": [],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)

with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
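# --- Illustrative sketch (not part of this test file) ---
# The mocks above pin down the updater contract this PR introduces:
# update(conversation_id=..., variable=...) once per assignment, then
# a single flush(). A minimal in-memory stand-in under that
# assumption; the real implementation persists to the database.
class _BufferingUpdater:
    def __init__(self):
        self._pending = []

    def update(self, conversation_id, variable):
        # Buffer writes so one node run's updates land together on flush().
        self._pending.append((conversation_id, variable))

    def flush(self):
        for conversation_id, variable in self._pending:
            ...  # persist `variable` for `conversation_id`
        self._pending.clear()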

+ 39
- 0
api/tests/unit_tests/core/workflow/test_variable_pool.py Show file

@@ -1,8 +1,12 @@
import pytest
from pydantic import ValidationError

from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from factories.variable_factory import build_segment, segment_to_variable


@pytest.fixture
@@ -44,3 +48,38 @@ def test_use_long_selector(pool):
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"


class TestVariablePool:
def test_constructor(self):
pool = VariablePool()
pool = VariablePool(
variable_dictionary={},
user_inputs={},
system_variables={},
environment_variables=[],
conversation_variables=[],
)

pool = VariablePool(
user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
environment_variables=[
segment_to_variable(
segment=build_segment(1),
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"],
name="env_var_1",
)
],
conversation_variables=[
segment_to_variable(
segment=build_segment("1"),
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"],
name="conv_var_1",
)
],
)

def test_constructor_with_invalid_system_variable_key(self):
with pytest.raises(ValidationError):
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore

+ 34
- 14
api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py Show file

@@ -1,22 +1,10 @@
from core.variables import SecretVariable
import dataclasses

from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser


def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
@@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]


def test_invalid_references():
@dataclasses.dataclass
class TestCase:
name: str
template: str

cases = [
TestCase(
name="lack of closing brace",
template="Hello, {{#sys.user_id#",
),
TestCase(
name="lack of opening brace",
template="Hello, #sys.user_id#}}",
),
TestCase(
name="lack selector name",
template="Hello, {{#sys#}}",
),
TestCase(
name="empty node name part",
template="Hello, {{#.user_id#}}",
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
selectors = variable_template_parser.extract_selectors_from_template(c.template)
assert selectors == [], fail_msg
parser = variable_template_parser.VariableTemplateParser(c.template)
assert parser.extract_variable_selectors() == [], fail_msg
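# --- Illustrative sketch (not part of this test file) ---
# Why the malformed templates above yield no selectors: a reference
# must look like {{#node.name#}}, with both braces and a dotted
# selector. A toy regex capturing that shape (the real logic lives
# in VariableTemplateParser; this is only an approximation):
import re

_TOY_REF = re.compile(r"\{\{#(\w+(?:\.\w+)+)#\}\}")


def _toy_extract(template: str) -> list:
    return [match.group(1).split(".") for match in _TOY_REF.finditer(template)]


assert _toy_extract("Hello, {{#sys.user_id#}}!") == [["sys", "user_id"]]
assert _toy_extract("Hello, {{#sys.user_id#") == []  # missing closing braces
assert _toy_extract("Hello, {{#sys#}}") == []  # no selector part after the node
assert _toy_extract("Hello, {{#.user_id#}}") == []  # empty node name part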

+ 0
- 0
api/tests/unit_tests/factories/test_variable_factory.py Show file


Some files were not shown because the diff is too large
