| @@ -83,9 +83,15 @@ jobs: | |||
| compose-file: | | |||
| docker/docker-compose.middleware.yaml | |||
| services: | | |||
| db | |||
| redis | |||
| sandbox | |||
| ssrf_proxy | |||
| - name: setup test config | |||
| run: | | |||
| cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env | |||
| - name: Run Workflow | |||
| run: uv run --project api bash dev/pytest/pytest_workflow.sh | |||
| @@ -84,6 +84,12 @@ jobs: | |||
| elasticsearch | |||
| oceanbase | |||
| - name: setup test config | |||
| run: | | |||
| echo $(pwd) | |||
| ls -lah . | |||
| cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env | |||
| - name: Check VDB Ready (TiDB) | |||
| run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py | |||
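The step above gates the VDB test suite on TiFlash being reachable. As a rough, hedged illustration only (the repository's actual `check_tiflash_ready.py` may work differently), such a readiness check usually polls the database until a probe query succeeds or a timeout expires:

```python
# Hypothetical readiness probe; host, port, credentials and the probe query
# are assumptions, not taken from the repository's script.
import time

import pymysql  # TiDB speaks the MySQL protocol


def wait_until_ready(host: str = "127.0.0.1", port: int = 4000, timeout: float = 120.0) -> None:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            conn = pymysql.connect(host=host, port=port, user="root")
            try:
                with conn.cursor() as cur:
                    cur.execute("SELECT 1")
                return  # probe succeeded, the database is ready
            finally:
                conn.close()
        except pymysql.err.OperationalError:
            time.sleep(2)  # not up yet, retry
    raise TimeoutError("database did not become ready in time")


if __name__ == "__main__":
    wait_until_ready()
```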
| @@ -230,6 +230,10 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) | |||
| Quickly deploy Dify to Alibaba Cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Using Alibaba Cloud Data Management | |||
| One-click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contributing | |||
| @@ -211,6 +211,11 @@ docker compose up -d | |||
| #### استخدام Alibaba Cloud للنشر | |||
| [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### استخدام Alibaba Cloud Data Management للنشر | |||
| انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## المساهمة | |||
| @@ -229,6 +229,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয় | |||
| [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contributing | |||
| @@ -225,6 +225,10 @@ docker compose up -d | |||
| 使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云 | |||
| #### 使用 阿里云数据管理DMS 部署 | |||
| 使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云 | |||
| ## Star History | |||
| @@ -225,6 +225,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contributing | |||
| @@ -225,6 +225,11 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contribuir | |||
| Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). | |||
| @@ -223,6 +223,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contribuer | |||
| @@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ | |||
| [こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 | |||
| - **Dify Community Editionのセルフホスティング</br>** | |||
| この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。 | |||
| この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。 | |||
| 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 | |||
| - **企業/組織向けのDify</br>** | |||
| @@ -223,6 +223,9 @@ docker compose up -d | |||
| #### Alibaba Cloud | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます | |||
| ## 貢献 | |||
| @@ -223,6 +223,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contributing | |||
| @@ -217,6 +217,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다 | |||
| ## 기여 | |||
| @@ -222,6 +222,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Contribuindo | |||
| @@ -223,6 +223,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Prispevam | |||
| @@ -216,6 +216,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın | |||
| ## Katkıda Bulunma | |||
| @@ -228,6 +228,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify | |||
| [阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### 使用 阿里雲數據管理DMS 進行部署 | |||
| 透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲 | |||
| ## 貢獻 | |||
| @@ -219,6 +219,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) | |||
| [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) | |||
| #### Alibaba Cloud Data Management | |||
| Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) | |||
| ## Đóng góp | |||
| @@ -1,6 +1,4 @@ | |||
| exclude = [ | |||
| "migrations/*", | |||
| ] | |||
| exclude = ["migrations/*"] | |||
| line-length = 120 | |||
| [format] | |||
| @@ -9,14 +7,14 @@ quote-style = "double" | |||
| [lint] | |||
| preview = false | |||
| select = [ | |||
| "B", # flake8-bugbear rules | |||
| "C4", # flake8-comprehensions | |||
| "E", # pycodestyle E rules | |||
| "F", # pyflakes rules | |||
| "FURB", # refurb rules | |||
| "I", # isort rules | |||
| "N", # pep8-naming | |||
| "PT", # flake8-pytest-style rules | |||
| "B", # flake8-bugbear rules | |||
| "C4", # flake8-comprehensions | |||
| "E", # pycodestyle E rules | |||
| "F", # pyflakes rules | |||
| "FURB", # refurb rules | |||
| "I", # isort rules | |||
| "N", # pep8-naming | |||
| "PT", # flake8-pytest-style rules | |||
| "PLC0208", # iteration-over-set | |||
| "PLC0414", # useless-import-alias | |||
| "PLE0604", # invalid-all-object | |||
| @@ -24,19 +22,19 @@ select = [ | |||
| "PLR0402", # manual-from-import | |||
| "PLR1711", # useless-return | |||
| "PLR1714", # repeated-equality-comparison | |||
| "RUF013", # implicit-optional | |||
| "RUF019", # unnecessary-key-check | |||
| "RUF100", # unused-noqa | |||
| "RUF101", # redirected-noqa | |||
| "RUF200", # invalid-pyproject-toml | |||
| "RUF022", # unsorted-dunder-all | |||
| "S506", # unsafe-yaml-load | |||
| "SIM", # flake8-simplify rules | |||
| "TRY400", # error-instead-of-exception | |||
| "TRY401", # verbose-log-message | |||
| "UP", # pyupgrade rules | |||
| "W191", # tab-indentation | |||
| "W605", # invalid-escape-sequence | |||
| "RUF013", # implicit-optional | |||
| "RUF019", # unnecessary-key-check | |||
| "RUF100", # unused-noqa | |||
| "RUF101", # redirected-noqa | |||
| "RUF200", # invalid-pyproject-toml | |||
| "RUF022", # unsorted-dunder-all | |||
| "S506", # unsafe-yaml-load | |||
| "SIM", # flake8-simplify rules | |||
| "TRY400", # error-instead-of-exception | |||
| "TRY401", # verbose-log-message | |||
| "UP", # pyupgrade rules | |||
| "W191", # tab-indentation | |||
| "W605", # invalid-escape-sequence | |||
| # security-related linting rules | |||
| # RCE protection (sort of) | |||
| "S102", # exec-builtin, disallow use of `exec` | |||
| @@ -47,36 +45,37 @@ select = [ | |||
| ] | |||
| ignore = [ | |||
| "E402", # module-import-not-at-top-of-file | |||
| "E711", # none-comparison | |||
| "E712", # true-false-comparison | |||
| "E721", # type-comparison | |||
| "E722", # bare-except | |||
| "F821", # undefined-name | |||
| "F841", # unused-variable | |||
| "E402", # module-import-not-at-top-of-file | |||
| "E711", # none-comparison | |||
| "E712", # true-false-comparison | |||
| "E721", # type-comparison | |||
| "E722", # bare-except | |||
| "F821", # undefined-name | |||
| "F841", # unused-variable | |||
| "FURB113", # repeated-append | |||
| "FURB152", # math-constant | |||
| "UP007", # non-pep604-annotation | |||
| "UP032", # f-string | |||
| "UP045", # non-pep604-annotation-optional | |||
| "B005", # strip-with-multi-characters | |||
| "B006", # mutable-argument-default | |||
| "B007", # unused-loop-control-variable | |||
| "B026", # star-arg-unpacking-after-keyword-arg | |||
| "B903", # class-as-data-structure | |||
| "B904", # raise-without-from-inside-except | |||
| "B905", # zip-without-explicit-strict | |||
| "N806", # non-lowercase-variable-in-function | |||
| "N815", # mixed-case-variable-in-class-scope | |||
| "PT011", # pytest-raises-too-broad | |||
| "SIM102", # collapsible-if | |||
| "SIM103", # needless-bool | |||
| "SIM105", # suppressible-exception | |||
| "SIM107", # return-in-try-except-finally | |||
| "SIM108", # if-else-block-instead-of-if-exp | |||
| "SIM113", # enumerate-for-loop | |||
| "SIM117", # multiple-with-statements | |||
| "SIM210", # if-expr-with-true-false | |||
| "UP007", # non-pep604-annotation | |||
| "UP032", # f-string | |||
| "UP045", # non-pep604-annotation-optional | |||
| "B005", # strip-with-multi-characters | |||
| "B006", # mutable-argument-default | |||
| "B007", # unused-loop-control-variable | |||
| "B026", # star-arg-unpacking-after-keyword-arg | |||
| "B903", # class-as-data-structure | |||
| "B904", # raise-without-from-inside-except | |||
| "B905", # zip-without-explicit-strict | |||
| "N806", # non-lowercase-variable-in-function | |||
| "N815", # mixed-case-variable-in-class-scope | |||
| "PT011", # pytest-raises-too-broad | |||
| "SIM102", # collapsible-if | |||
| "SIM103", # needless-bool | |||
| "SIM105", # suppressible-exception | |||
| "SIM107", # return-in-try-except-finally | |||
| "SIM108", # if-else-block-instead-of-if-exp | |||
| "SIM113", # enumerate-for-loop | |||
| "SIM117", # multiple-with-statements | |||
| "SIM210", # if-expr-with-true-false | |||
| "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ | |||
| ] | |||
| [lint.per-file-ignores] | |||
| @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): | |||
| CURRENT_VERSION: str = Field( | |||
| description="Dify version", | |||
| default="1.4.3", | |||
| default="1.5.0", | |||
| ) | |||
| COMMIT_SHA: str = Field( | |||
| @@ -63,6 +63,7 @@ from .app import ( | |||
| statistic, | |||
| workflow, | |||
| workflow_app_log, | |||
| workflow_draft_variable, | |||
| workflow_run, | |||
| workflow_statistic, | |||
| ) | |||
| @@ -1,5 +1,6 @@ | |||
| import json | |||
| import logging | |||
| from collections.abc import Sequence | |||
| from typing import cast | |||
| from flask import abort, request | |||
| @@ -18,10 +19,12 @@ from controllers.console.app.error import ( | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.models import File | |||
| from extensions.ext_database import db | |||
| from factories import variable_factory | |||
| from factories import file_factory, variable_factory | |||
| from fields.workflow_fields import workflow_fields, workflow_pagination_fields | |||
| from fields.workflow_run_fields import workflow_run_node_execution_fields | |||
| from libs import helper | |||
| @@ -30,6 +33,7 @@ from libs.login import current_user, login_required | |||
| from models import App | |||
| from models.account import Account | |||
| from models.model import AppMode | |||
| from models.workflow import Workflow | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.errors.llm import InvokeRateLimitError | |||
| @@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE | |||
| logger = logging.getLogger(__name__) | |||
| # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing | |||
| # at the controller level rather than in the workflow logic. This would improve separation | |||
| # of concerns and make the code more maintainable. | |||
| def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]: | |||
| files = files or [] | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| file_objs: Sequence[File] = [] | |||
| if file_extra_config is None: | |||
| return file_objs | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=workflow.tenant_id, | |||
| config=file_extra_config, | |||
| ) | |||
| return file_objs | |||
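For orientation, a hedged sketch of how this helper is meant to be called; the mapping keys mirror the file payloads documented later in this patch and are illustrative, not an exhaustive schema:

```python
# Illustrative call only; `draft_workflow` is the workflow fetched by the
# controller, and the mapping below follows the documented "local_file" payload shape.
files = _parse_file(
    draft_workflow,
    [
        {
            "type": "image",
            "transfer_method": "local_file",
            "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190",
        }
    ],
)
```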
| class DraftWorkflowApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("query", type=str, required=False, location="json", default="") | |||
| parser.add_argument("files", type=list, location="json", default=[]) | |||
| args = parser.parse_args() | |||
| inputs = args.get("inputs") | |||
| if inputs == None: | |||
| user_inputs = args.get("inputs") | |||
| if user_inputs is None: | |||
| raise ValueError("missing inputs") | |||
| workflow_srv = WorkflowService() | |||
| # fetch draft workflow by app_model | |||
| draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model) | |||
| if not draft_workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| files = _parse_file(draft_workflow, args.get("files")) | |||
| workflow_service = WorkflowService() | |||
| workflow_node_execution = workflow_service.run_draft_workflow_node( | |||
| app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user | |||
| app_model=app_model, | |||
| draft_workflow=draft_workflow, | |||
| node_id=node_id, | |||
| user_inputs=user_inputs, | |||
| account=current_user, | |||
| query=args.get("query", ""), | |||
| files=files, | |||
| ) | |||
| return workflow_node_execution | |||
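Taken together, a single-node draft run now accepts an optional `files` list next to `inputs` and `query`. A hedged client-side sketch follows; the route for `DraftWorkflowNodeRunApi` is not shown in this diff, so the URL below is an assumed placeholder:

```python
# Placeholder request; substitute the real DraftWorkflowNodeRunApi URL and a
# valid console token. The JSON body mirrors the parser arguments above.
import requests

NODE_RUN_URL = "<console-api-base>/apps/<app_id>/workflows/draft/nodes/<node_id>/run"  # assumed shape

resp = requests.post(
    NODE_RUN_URL,
    json={
        "inputs": {"text": "hello"},
        "query": "",
        "files": [],  # same mapping shape as the other file payloads in this patch
    },
    headers={"Authorization": "Bearer <console-token>"},
)
print(resp.json())
```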
| @@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource): | |||
| return None, 204 | |||
| class DraftWorkflowNodeLastRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||
| @marshal_with(workflow_run_node_execution_fields) | |||
| def get(self, app_model: App, node_id: str): | |||
| srv = WorkflowService() | |||
| workflow = srv.get_draft_workflow(app_model) | |||
| if not workflow: | |||
| raise NotFound("Workflow not found") | |||
| node_exec = srv.get_node_last_run( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| node_id=node_id, | |||
| ) | |||
| if node_exec is None: | |||
| raise NotFound("last run not found") | |||
| return node_exec | |||
| api.add_resource( | |||
| DraftWorkflowApi, | |||
| "/apps/<uuid:app_id>/workflows/draft", | |||
| @@ -795,3 +853,7 @@ api.add_resource( | |||
| WorkflowByIdApi, | |||
| "/apps/<uuid:app_id>/workflows/<string:workflow_id>", | |||
| ) | |||
| api.add_resource( | |||
| DraftWorkflowNodeLastRunApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run", | |||
| ) | |||
| @@ -0,0 +1,421 @@ | |||
| import logging | |||
| from typing import Any, NoReturn | |||
| from flask import Response | |||
| from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ( | |||
| DraftWorkflowNotExist, | |||
| ) | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from controllers.web.error import InvalidArgumentError, NotFoundError | |||
| from core.variables.segment_group import SegmentGroup | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from factories.file_factory import build_from_mapping, build_from_mappings | |||
| from factories.variable_factory import build_segment_with_type | |||
| from libs.login import current_user, login_required | |||
| from models import App, AppMode, db | |||
| from models.workflow import WorkflowDraftVariable | |||
| from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService | |||
| from services.workflow_service import WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| if isinstance(value, FileSegment): | |||
| return value.value.model_dump() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| return [i.model_dump() for i in value.value] | |||
| elif isinstance(value, SegmentGroup): | |||
| return [_convert_values_to_json_serializable_object(i) for i in value.value] | |||
| else: | |||
| return value.value | |||
| def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: | |||
| value = variable.get_value() | |||
| # create a copy of the value to avoid affecting the model cache. | |||
| value = value.model_copy(deep=True) | |||
| # Refresh the URL signature before returning it to the client. | |||
| if isinstance(value, FileSegment): | |||
| file = value.value | |||
| file.remote_url = file.generate_url() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| files = value.value | |||
| for file in files: | |||
| file.remote_url = file.generate_url() | |||
| return _convert_values_to_json_serializable_object(value) | |||
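The `model_copy(deep=True)` call matters because the segment may be served from a model cache; mutating it in place would leak the refreshed URL back into the cached object. A stand-alone illustration of the same copy-then-refresh pattern (plain Pydantic stand-ins, not the real `FileSegment`/`File` models):

```python
# Stand-ins only; this demonstrates the pattern, not the actual Dify models.
from pydantic import BaseModel


class FakeFile(BaseModel):
    remote_url: str = ""


class FakeSegment(BaseModel):
    value: FakeFile


cached = FakeSegment(value=FakeFile(remote_url="stale"))
fresh = cached.model_copy(deep=True)   # never mutate the cached instance
fresh.value.remote_url = "signed-url"  # refresh only on the copy
assert cached.value.remote_url == "stale"
```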
| def _create_pagination_parser(): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "page", | |||
| type=inputs.int_range(1, 100_000), | |||
| required=False, | |||
| default=1, | |||
| location="args", | |||
| help="the page of data requested", | |||
| ) | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| return parser | |||
| _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { | |||
| "id": fields.String, | |||
| "type": fields.String(attribute=lambda model: model.get_variable_type()), | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( | |||
| _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, | |||
| value=fields.Raw(attribute=_serialize_var_value), | |||
| ) | |||
| _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { | |||
| "id": fields.String, | |||
| "type": fields.String(attribute=lambda _: "env"), | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), | |||
| } | |||
| def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: | |||
| return var_list.variables | |||
| _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), | |||
| "total": fields.Raw(), | |||
| } | |||
| _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), | |||
| } | |||
| def _api_prerequisite(f): | |||
| """Common prerequisites for all draft workflow variable APIs. | |||
| It ensures the following conditions are satisfied: | |||
| - Dify has been properly set up. | |||
| - The requesting user has logged in and completed account initialization. | |||
| - The requested app is a workflow or a chat flow. | |||
| - The requesting user has edit permission for the app. | |||
| """ | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||
| def wrapper(*args, **kwargs): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| return f(*args, **kwargs) | |||
| return wrapper | |||
| class WorkflowVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) | |||
| def get(self, app_model: App): | |||
| """ | |||
| Get draft workflow variables (without values). | |||
| """ | |||
| parser = _create_pagination_parser() | |||
| args = parser.parse_args() | |||
| # fetch draft workflow by app_model | |||
| workflow_service = WorkflowService() | |||
| workflow_exist = workflow_service.is_workflow_exist(app_model=app_model) | |||
| if not workflow_exist: | |||
| raise DraftWorkflowNotExist() | |||
| # fetch draft workflow by app_model | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| workflow_vars = draft_var_srv.list_variables_without_values( | |||
| app_id=app_model.id, | |||
| page=args.page, | |||
| limit=args.limit, | |||
| ) | |||
| return workflow_vars | |||
| @_api_prerequisite | |||
| def delete(self, app_model: App): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| draft_var_srv.delete_workflow_variables(app_model.id) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| def validate_node_id(node_id: str) -> NoReturn | None: | |||
| if node_id in [ | |||
| CONVERSATION_VARIABLE_NODE_ID, | |||
| SYSTEM_VARIABLE_NODE_ID, | |||
| ]: | |||
| # NOTE(QuantumGhost): Although system and conversation variables are stored as node variables | |||
| # with specific `node_id`s in the database, we still want to keep the APIs separate. By disallowing | |||
| # access to system and conversation variables in `WorkflowDraftNodeVariableListApi`, | |||
| # we reduce the risk of API users depending on implementation details of the API. | |||
| # | |||
| # ref: [Hyrum's Law](https://www.hyrumslaw.com/) | |||
| raise InvalidArgumentError( | |||
| f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", | |||
| ) | |||
| return None | |||
| class NodeVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) | |||
| def get(self, app_model: App, node_id: str): | |||
| validate_node_id(node_id) | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) | |||
| return node_vars | |||
| @_api_prerequisite | |||
| def delete(self, app_model: App, node_id: str): | |||
| validate_node_id(node_id) | |||
| srv = WorkflowDraftVariableService(db.session()) | |||
| srv.delete_node_variables(app_model.id, node_id) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| class VariableApi(Resource): | |||
| _PATCH_NAME_FIELD = "name" | |||
| _PATCH_VALUE_FIELD = "value" | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def get(self, app_model: App, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != app_model.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| return variable | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def patch(self, app_model: App, variable_id: str): | |||
| # Request payload for file types: | |||
| # | |||
| # Local File: | |||
| # | |||
| # { | |||
| # "type": "image", | |||
| # "transfer_method": "local_file", | |||
| # "url": "", | |||
| # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" | |||
| # } | |||
| # | |||
| # Remote File: | |||
| # | |||
| # | |||
| # { | |||
| # "type": "image", | |||
| # "transfer_method": "remote_url", | |||
| # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", | |||
| # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" | |||
| # } | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") | |||
| # Parse 'value' field as-is to maintain its original data structure | |||
| parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| args = parser.parse_args(strict=True) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != app_model.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| new_name = args.get(self._PATCH_NAME_FIELD, None) | |||
| raw_value = args.get(self._PATCH_VALUE_FIELD, None) | |||
| if new_name is None and raw_value is None: | |||
| return variable | |||
| new_value = None | |||
| if raw_value is not None: | |||
| if variable.value_type == SegmentType.FILE: | |||
| if not isinstance(raw_value, dict): | |||
| raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") | |||
| raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) | |||
| elif variable.value_type == SegmentType.ARRAY_FILE: | |||
| if not isinstance(raw_value, list): | |||
| raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") | |||
| if len(raw_value) > 0 and not isinstance(raw_value[0], dict): | |||
| raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") | |||
| raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) | |||
| new_value = build_segment_with_type(variable.value_type, raw_value) | |||
| draft_var_srv.update_variable(variable, name=new_name, value=new_value) | |||
| db.session.commit() | |||
| return variable | |||
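The payloads in the comments above translate into a PATCH request like the following hedged sketch; the path comes from the `VariableApi` registration at the bottom of this file, the console API base and IDs are placeholders, and the target variable must already be file-typed:

```python
# Placeholder request; the "value" object reuses the documented local_file payload.
import requests

url = "<console-api-base>/apps/<app_id>/workflows/draft/variables/<variable_id>"
payload = {
    "value": {
        "type": "image",
        "transfer_method": "local_file",
        "url": "",
        "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190",
    }
}
resp = requests.patch(url, json=payload, headers={"Authorization": "Bearer <console-token>"})
print(resp.json())
```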
| @_api_prerequisite | |||
| def delete(self, app_model: App, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != app_model.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| draft_var_srv.delete_variable(variable) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| class VariableResetApi(Resource): | |||
| @_api_prerequisite | |||
| def put(self, app_model: App, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| workflow_srv = WorkflowService() | |||
| draft_workflow = workflow_srv.get_draft_workflow(app_model) | |||
| if draft_workflow is None: | |||
| raise NotFoundError( | |||
| f"Draft workflow not found, app_id={app_model.id}", | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != app_model.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| resetted = draft_var_srv.reset_variable(draft_workflow, variable) | |||
| db.session.commit() | |||
| if resetted is None: | |||
| return Response("", 204) | |||
| else: | |||
| return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def _get_variable_list(app_model: App, node_id: str) -> WorkflowDraftVariableList: | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| if node_id == CONVERSATION_VARIABLE_NODE_ID: | |||
| draft_vars = draft_var_srv.list_conversation_variables(app_model.id) | |||
| elif node_id == SYSTEM_VARIABLE_NODE_ID: | |||
| draft_vars = draft_var_srv.list_system_variables(app_model.id) | |||
| else: | |||
| draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id) | |||
| return draft_vars | |||
| class ConversationVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) | |||
| def get(self, app_model: App): | |||
| # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table | |||
| # so their IDs can be returned to the caller. | |||
| workflow_srv = WorkflowService() | |||
| draft_workflow = workflow_srv.get_draft_workflow(app_model) | |||
| if draft_workflow is None: | |||
| raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) | |||
| db.session.commit() | |||
| return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) | |||
| class SystemVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) | |||
| def get(self, app_model: App): | |||
| return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) | |||
| class EnvironmentVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| def get(self, app_model: App): | |||
| """ | |||
| Get draft workflow environment variables. | |||
| """ | |||
| # fetch draft workflow by app_model | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.get_draft_workflow(app_model=app_model) | |||
| if workflow is None: | |||
| raise DraftWorkflowNotExist() | |||
| env_vars = workflow.environment_variables | |||
| env_vars_list = [] | |||
| for v in env_vars: | |||
| env_vars_list.append( | |||
| { | |||
| "id": v.id, | |||
| "type": "env", | |||
| "name": v.name, | |||
| "description": v.description, | |||
| "selector": v.selector, | |||
| "value_type": v.value_type.value, | |||
| "value": v.value, | |||
| # Do not track edited for env vars. | |||
| "edited": False, | |||
| "visible": True, | |||
| "editable": True, | |||
| } | |||
| ) | |||
| return {"items": env_vars_list} | |||
| api.add_resource( | |||
| WorkflowVariableCollectionApi, | |||
| "/apps/<uuid:app_id>/workflows/draft/variables", | |||
| ) | |||
| api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables") | |||
| api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>") | |||
| api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset") | |||
| api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables") | |||
| api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables") | |||
| api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables") | |||
| @@ -8,6 +8,15 @@ from libs.login import current_user | |||
| from models import App, AppMode | |||
| def _load_app_model(app_id: str) -> Optional[App]: | |||
| app_model = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| return app_model | |||
| def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): | |||
| def decorator(view_func): | |||
| @wraps(view_func) | |||
| @@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ | |||
| del kwargs["app_id"] | |||
| app_model = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| app_model = _load_app_model(app_id) | |||
| if not app_model: | |||
| raise AppNotFoundError() | |||
| @@ -5,7 +5,7 @@ from typing import cast | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, fields, marshal, marshal_with, reqparse | |||
| from flask_restful import Resource, marshal, marshal_with, reqparse | |||
| from sqlalchemy import asc, desc, select | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| @@ -239,12 +239,10 @@ class DatasetDocumentListApi(Resource): | |||
| return response | |||
| documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(documents_and_batch_fields) | |||
| @marshal_with(dataset_and_document_fields) | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id): | |||
| @@ -290,6 +288,8 @@ class DatasetDocumentListApi(Resource): | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except QuotaExceededError: | |||
| @@ -297,7 +297,7 @@ class DatasetDocumentListApi(Resource): | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| return {"documents": documents, "batch": batch} | |||
| return {"dataset": dataset, "documents": documents, "batch": batch} | |||
| @setup_required | |||
| @login_required | |||
| @@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource): | |||
| return { | |||
| "result": "success", | |||
| "invitation_results": invitation_results, | |||
| "tenant_id": str(current_user.current_tenant.id), | |||
| }, 201 | |||
| @@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource): | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| return {"result": "success"}, 204 | |||
| return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 | |||
| class MemberUpdateRoleApi(Resource): | |||
| @@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException): | |||
| error_code = "rate_limit_error" | |||
| description = "Rate Limit Error" | |||
| code = 429 | |||
| class NotFoundError(BaseHTTPException): | |||
| error_code = "not_found" | |||
| code = 404 | |||
| class InvalidArgumentError(BaseHTTPException): | |||
| error_code = "invalid_param" | |||
| code = 400 | |||
| @@ -104,6 +104,7 @@ class VariableEntity(BaseModel): | |||
| Variable Entity. | |||
| """ | |||
| # `variable` records the name of the variable in user inputs. | |||
| variable: str | |||
| label: str | |||
| description: str = "" | |||
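The added comment deserves a concrete illustration: `variable` is the lookup key into the user inputs mapping, while `label` is what the UI shows. A self-contained hedged sketch with a stand-in model (the real `VariableEntity` defines more fields than this hunk shows):

```python
# Stand-in model for illustration only; not the actual VariableEntity.
from pydantic import BaseModel


class VariableEntitySketch(BaseModel):
    variable: str        # key under which the value appears in user inputs
    label: str           # display name shown in the UI
    description: str = ""


entity = VariableEntitySketch(variable="user_name", label="User Name")
user_inputs = {"user_name": "Alice"}
print(user_inputs[entity.variable])  # -> Alice
```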
| @@ -29,13 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.message import MessageNotExistsError | |||
| from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -116,6 +117,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| # parse files | |||
| # TODO(QuantumGhost): Move file parsing logic to the API controller layer | |||
| # for better separation of concerns. | |||
| # | |||
| # For implementation reference, see the `_parse_file` function and | |||
| # `DraftWorkflowNodeRunApi` class which handle this properly. | |||
| files = args["files"] if args.get("files") else [] | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| if file_extra_config: | |||
| @@ -261,6 +267,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| return self._generate( | |||
| workflow=workflow, | |||
| @@ -271,6 +284,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| conversation=None, | |||
| stream=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def single_loop_generate( | |||
| @@ -336,6 +350,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| return self._generate( | |||
| workflow=workflow, | |||
| @@ -346,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| conversation=None, | |||
| stream=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def _generate( | |||
| @@ -359,6 +381,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| conversation: Optional[Conversation] = None, | |||
| stream: bool = True, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -410,6 +433,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| "conversation_id": conversation.id, | |||
| "message_id": message.id, | |||
| "context": context, | |||
| "variable_loader": variable_loader, | |||
| }, | |||
| ) | |||
| @@ -438,6 +462,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation_id: str, | |||
| message_id: str, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| ) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| @@ -454,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| if message is None: | |||
| raise MessageNotExistsError("Message not exists") | |||
| # chatbot app | |||
| runner = AdvancedChatAppRunner( | |||
| @@ -464,6 +487,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| dialogue_count=self._dialogue_count, | |||
| variable_loader=variable_loader, | |||
| ) | |||
| runner.run() | |||
| @@ -19,6 +19,7 @@ from core.moderation.base import ModerationError | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.enums import UserFrom | |||
| @@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| conversation: Conversation, | |||
| message: Message, | |||
| dialogue_count: int, | |||
| variable_loader: VariableLoader, | |||
| ) -> None: | |||
| super().__init__(queue_manager) | |||
| super().__init__(queue_manager, variable_loader) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| self.message = message | |||
| self._dialogue_count = dialogue_count | |||
| def _get_app_id(self) -> str: | |||
| return self.application_generate_entity.app_config.app_id | |||
| def run(self) -> None: | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| @@ -26,7 +26,6 @@ from factories import file_factory | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from models import Account, App, EndUser | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.message import MessageNotExistsError | |||
| logger = logging.getLogger(__name__) | |||
| @@ -124,6 +123,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| override_model_config_dict["retriever_resource"] = {"enabled": True} | |||
| # parse files | |||
| # TODO(QuantumGhost): Move file parsing logic to the API controller layer | |||
| # for better separation of concerns. | |||
| # | |||
| # For implementation reference, see the `_parse_file` function and | |||
| # `DraftWorkflowNodeRunApi` class which handle this properly. | |||
| files = args.get("files") or [] | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| @@ -233,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| if message is None: | |||
| raise MessageNotExistsError("Message not exists") | |||
| # chatbot app | |||
| runner = AgentChatAppRunner() | |||
| @@ -25,7 +25,6 @@ from factories import file_factory | |||
| from models.account import Account | |||
| from models.model import App, EndUser | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.message import MessageNotExistsError | |||
| logger = logging.getLogger(__name__) | |||
| @@ -115,6 +114,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| override_model_config_dict["retriever_resource"] = {"enabled": True} | |||
| # parse files | |||
| # TODO(QuantumGhost): Move file parsing logic to the API controller layer | |||
| # for better separation of concerns. | |||
| # | |||
| # For implementation reference, see the `_parse_file` function and | |||
| # `DraftWorkflowNodeRunApi` class which handle this properly. | |||
| files = args["files"] if args.get("files") else [] | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| @@ -219,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| if message is None: | |||
| raise MessageNotExistsError("Message not exists") | |||
| # chatbot app | |||
| runner = ChatAppRunner() | |||
| @@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| @@ -125,7 +126,7 @@ class WorkflowResponseConverter: | |||
| id=workflow_execution.id_, | |||
| workflow_id=workflow_execution.workflow_id, | |||
| status=workflow_execution.status, | |||
| outputs=workflow_execution.outputs, | |||
| outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs), | |||
| error=workflow_execution.error_message, | |||
| elapsed_time=workflow_execution.elapsed_time, | |||
| total_tokens=workflow_execution.total_tokens, | |||
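This call site, and the similar ones below, now route outputs through `WorkflowRuntimeTypeConverter().to_json_encodable(...)` before embedding them in a stream response, so values such as files and segments become JSON-serializable. A hedged usage sketch, assuming it runs inside the Dify API codebase where the import resolves; the outputs mapping is made-up example data:

```python
# Assumes the Dify `api` package is importable; the dict is illustrative.
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter

outputs = {"text": "done", "score": 0.93}
encodable = WorkflowRuntimeTypeConverter().to_json_encodable(outputs)
# `encodable` can now be passed safely to json.dumps() or a stream response.
```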
| @@ -202,6 +203,8 @@ class WorkflowResponseConverter: | |||
| if not workflow_node_execution.finished_at: | |||
| return None | |||
| json_converter = WorkflowRuntimeTypeConverter() | |||
| return NodeFinishStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_execution_id, | |||
| @@ -214,7 +217,7 @@ class WorkflowResponseConverter: | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs, | |||
| process_data=workflow_node_execution.process_data, | |||
| outputs=workflow_node_execution.outputs, | |||
| outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), | |||
| status=workflow_node_execution.status, | |||
| error=workflow_node_execution.error, | |||
| elapsed_time=workflow_node_execution.elapsed_time, | |||
| @@ -245,6 +248,8 @@ class WorkflowResponseConverter: | |||
| if not workflow_node_execution.finished_at: | |||
| return None | |||
| json_converter = WorkflowRuntimeTypeConverter() | |||
| return NodeRetryStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_execution_id, | |||
| @@ -257,7 +262,7 @@ class WorkflowResponseConverter: | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs, | |||
| process_data=workflow_node_execution.process_data, | |||
| outputs=workflow_node_execution.outputs, | |||
| outputs=json_converter.to_json_encodable(workflow_node_execution.outputs), | |||
| status=workflow_node_execution.status, | |||
| error=workflow_node_execution.error, | |||
| elapsed_time=workflow_node_execution.elapsed_time, | |||
| @@ -376,6 +381,7 @@ class WorkflowResponseConverter: | |||
| workflow_execution_id: str, | |||
| event: QueueIterationCompletedEvent, | |||
| ) -> IterationNodeCompletedStreamResponse: | |||
| json_converter = WorkflowRuntimeTypeConverter() | |||
| return IterationNodeCompletedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_execution_id, | |||
| @@ -384,7 +390,7 @@ class WorkflowResponseConverter: | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| outputs=event.outputs, | |||
| outputs=json_converter.to_json_encodable(event.outputs), | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs or {}, | |||
| @@ -463,7 +469,7 @@ class WorkflowResponseConverter: | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| outputs=event.outputs, | |||
| outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs or {}, | |||
| @@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| # parse files | |||
| # TODO(QuantumGhost): Move file parsing logic to the API controller layer | |||
| # for better separation of concerns. | |||
| # | |||
| # For implementation reference, see the `_parse_file` function and | |||
| # `DraftWorkflowNodeRunApi` class which handle this properly. | |||
| files = args["files"] if args.get("files") else [] | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| @@ -196,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| try: | |||
| # get message | |||
| message = self._get_message(message_id) | |||
| if message is None: | |||
| raise MessageNotExistsError() | |||
| # chatbot app | |||
| runner = CompletionAppRunner() | |||
| @@ -29,6 +29,7 @@ from models.enums import CreatorUserRole | |||
| from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile | |||
| from services.errors.app_model_config import AppModelConfigBrokenError | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| from services.errors.message import MessageNotExistsError | |||
| logger = logging.getLogger(__name__) | |||
| @@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| return introduction or "" | |||
| def _get_conversation(self, conversation_id: str): | |||
| def _get_conversation(self, conversation_id: str) -> Conversation: | |||
| """ | |||
| Get conversation by conversation id | |||
| :param conversation_id: conversation id | |||
| @@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() | |||
| if not conversation: | |||
| raise ConversationNotExistsError() | |||
| raise ConversationNotExistsError("Conversation not exists") | |||
| return conversation | |||
| def _get_message(self, message_id: str) -> Optional[Message]: | |||
| def _get_message(self, message_id: str) -> Message: | |||
| """ | |||
| Get message by message id | |||
| :param message_id: message id | |||
| @@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| """ | |||
| message = db.session.query(Message).filter(Message.id == message_id).first() | |||
| if message is None: | |||
| raise MessageNotExistsError("Message not exists") | |||
| return message | |||
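With this change `_get_conversation` and `_get_message` raise instead of returning `None`, which is why the explicit `None` checks were deleted from the generators earlier in this patch. A stand-alone hedged sketch of the same Optional-return-to-exception refactor, using stand-in names rather than Dify code:

```python
# Stand-in names; demonstrates the contract change only.
class MessageNotExistsError(Exception):
    pass


_MESSAGES = {"m1": "hello"}


def get_message(message_id: str) -> str:
    message = _MESSAGES.get(message_id)
    if message is None:
        raise MessageNotExistsError("Message not exists")
    return message


# Callers no longer need an `if message is None` guard:
try:
    print(get_message("m2"))
except MessageNotExistsError:
    print("missing message")
```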
| @@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| files: Sequence[Mapping[str, Any]] = args.get("files") or [] | |||
| # parse files | |||
| # TODO(QuantumGhost): Move file parsing logic to the API controller layer | |||
| # for better separation of concerns. | |||
| # | |||
| # For implementation reference, see the `_parse_file` function and | |||
| # `DraftWorkflowNodeRunApi` class which handle this properly. | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| system_files = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| @@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -219,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| "queue_manager": queue_manager, | |||
| "context": context, | |||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||
| "variable_loader": variable_loader, | |||
| }, | |||
| ) | |||
| @@ -303,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| @@ -313,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def single_loop_generate( | |||
| @@ -379,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| @@ -389,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def _generate_worker( | |||
| @@ -397,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| @@ -415,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| variable_loader=variable_loader, | |||
| ) | |||
| runner.run() | |||
| @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.enums import UserFrom | |||
| @@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| self, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| @@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| :param queue_manager: application queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| super().__init__(queue_manager, variable_loader) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| def _get_app_id(self) -> str: | |||
| return self.application_generate_entity.app_config.app_id | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| @@ -1,6 +1,8 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy.orm import Session | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.entities.queue_entities import ( | |||
| @@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| AgentLogEvent, | |||
| BaseNodeEvent, | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| @@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import ( | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| from services.workflow_draft_variable_service import ( | |||
| DraftVariableSaver, | |||
| ) | |||
| class WorkflowBasedAppRunner(AppRunner): | |||
| def __init__(self, queue_manager: AppQueueManager): | |||
| def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: | |||
| self.queue_manager = queue_manager | |||
| self._variable_loader = variable_loader | |||
| def _get_app_id(self) -> str: | |||
| raise NotImplementedError("not implemented") | |||
| def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: | |||
| """ | |||
| @@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| load_into_variable_pool( | |||
| variable_loader=self._variable_loader, | |||
| variable_pool=variable_pool, | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| ) | |||
| WorkflowEntry.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| @@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| ) | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| load_into_variable_pool( | |||
| self._variable_loader, | |||
| variable_pool=variable_pool, | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| ) | |||
| WorkflowEntry.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| @@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| self._save_draft_var_for_event(event) | |||
| elif isinstance(event, NodeRunFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeFailedEvent( | |||
| @@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| self._save_draft_var_for_event(event) | |||
| elif isinstance(event, NodeInIterationFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeInIterationFailedEvent( | |||
| @@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||
| self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | |||
| def _save_draft_var_for_event(self, event: BaseNodeEvent): | |||
| run_result = event.route_node_state.node_run_result | |||
| if run_result is None: | |||
| return | |||
| process_data = run_result.process_data | |||
| outputs = run_result.outputs | |||
| with Session(bind=db.engine) as session, session.begin(): | |||
| draft_var_saver = DraftVariableSaver( | |||
| session=session, | |||
| app_id=self._get_app_id(), | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| # FIXME(QuantumGhost): relying on the private state of queue_manager is not ideal. | |||
| invoke_from=self.queue_manager._invoke_from, | |||
| node_execution_id=event.id, | |||
| enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, | |||
| ) | |||
| draft_var_saver.save(process_data=process_data, outputs=outputs) | |||
| def _remove_first_element_from_variable_string(key: str) -> str: | |||
| """ | |||
| Remove the first element from the prefix. | |||
| """ | |||
| prefix, remaining = key.split(".", maxsplit=1) | |||
| return remaining | |||
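| # As a quick illustration of the helper above (illustrative check only, not part of the change): | |||
| # splitting on the first "." drops the leading node id from a mapping key, e.g. one of the | |||
| # keys produced by extract_variable_selector_to_variable_mapping. | |||
| key = "1747829548239.#1747829667553.result#" | |||
| assert _remove_first_element_from_variable_string(key) == "#1747829667553.result#" | |||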
| @@ -17,9 +17,24 @@ class InvokeFrom(Enum): | |||
| Invoke From. | |||
| """ | |||
| # SERVICE_API indicates that this invocation is from an API call to a Dify app. | |||
| # | |||
| # Description of service api in Dify docs: | |||
| # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis | |||
| SERVICE_API = "service-api" | |||
| # WEB_APP indicates that this invocation is from | |||
| # the web app of the workflow (or chatflow). | |||
| # | |||
| # Description of web app in Dify docs: | |||
| # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README | |||
| WEB_APP = "web-app" | |||
| # EXPLORE indicates that this invocation is from | |||
| # the workflow (or chatflow) explore page. | |||
| EXPLORE = "explore" | |||
| # DEBUGGER indicates that this invocation is from | |||
| # the workflow (or chatflow) edit page. | |||
| DEBUGGER = "debugger" | |||
| @classmethod | |||
| @@ -1 +1,11 @@ | |||
| from typing import Any | |||
| # TODO(QuantumGhost): Refactor variable type identification. Instead of directly | |||
| # comparing `dify_model_identity` with constants throughout the codebase, extract | |||
| # this logic into a dedicated function. This would encapsulate the implementation | |||
| # details of how different variable types are identified. | |||
| FILE_MODEL_IDENTITY = "__dify__file__" | |||
| def maybe_file_object(o: Any) -> bool: | |||
| return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
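| # Illustrative use of the helper above: a dict is treated as a serialized File only when it | |||
| # carries the expected identity marker (the values here are made-up examples). | |||
| assert maybe_file_object({"dify_model_identity": FILE_MODEL_IDENTITY, "id": "f-1"}) | |||
| assert not maybe_file_object({"id": "f-1"}) | |||
| assert not maybe_file_object(["not", "a", "dict"]) | |||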
| @@ -534,7 +534,7 @@ class IndexingRunner: | |||
| # chunk nodes by chunk size | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": | |||
| # create keyword index | |||
| create_keyword_thread = threading.Thread( | |||
| target=self._process_keyword_index, | |||
| @@ -572,7 +572,7 @@ class IndexingRunner: | |||
| for future in futures: | |||
| tokens += future.result() | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": | |||
| create_keyword_thread.join() | |||
| indexing_end_at = time.perf_counter() | |||
| @@ -1,3 +1,4 @@ | |||
| import binascii | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): | |||
| provider: str, | |||
| system_credentials: Mapping[str, Any], | |||
| ) -> PluginOAuthAuthorizationUrlResponse: | |||
| return self._request_with_plugin_daemon_response( | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", | |||
| PluginOAuthAuthorizationUrlResponse, | |||
| @@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient): | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||
| def get_credentials( | |||
| self, | |||
| @@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient): | |||
| # encode request to raw http request | |||
| raw_request_bytes = self._convert_request_to_raw_data(request) | |||
| return self._request_with_plugin_daemon_response( | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/get_credentials", | |||
| PluginOAuthCredentialsResponse, | |||
| @@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient): | |||
| "data": { | |||
| "provider": provider, | |||
| "system_credentials": system_credentials, | |||
| "raw_request_bytes": raw_request_bytes, | |||
| # for json serialization | |||
| "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), | |||
| }, | |||
| }, | |||
| headers={ | |||
| @@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient): | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for authorization URL request.") | |||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | |||
| """ | |||
| @@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient): | |||
| """ | |||
| # Start with the request line | |||
| method = request.method | |||
| path = request.path | |||
| path = request.full_path | |||
| protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") | |||
| raw_data = f"{method} {path} {protocol}\r\n".encode() | |||
| @@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector = Vector(dataset) | |||
| vector.create(documents) | |||
| with_keywords = False | |||
| if with_keywords: | |||
| keywords_list = kwargs.get("keywords_list") | |||
| keyword = Keyword(dataset) | |||
| @@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| vector.delete_by_ids(node_ids) | |||
| else: | |||
| vector.delete() | |||
| with_keywords = False | |||
| if with_keywords: | |||
| keyword = Keyword(dataset) | |||
| if node_ids: | |||
| @@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import ( | |||
| WorkflowType, | |||
| ) | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| @@ -152,7 +153,11 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): | |||
| db_model.version = domain_model.workflow_version | |||
| db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None | |||
| db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None | |||
| db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None | |||
| db_model.outputs = ( | |||
| json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs)) | |||
| if domain_model.outputs | |||
| else None | |||
| ) | |||
| db_model.status = domain_model.status | |||
| db_model.error = domain_model.error_message if domain_model.error_message else None | |||
| db_model.total_tokens = domain_model.total_tokens | |||
| @@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import ( | |||
| ) | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter | |||
| from models import ( | |||
| Account, | |||
| CreatorUserRole, | |||
| @@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| if not self._creator_user_role: | |||
| raise ValueError("created_by_role is required in repository constructor") | |||
| json_converter = WorkflowRuntimeTypeConverter() | |||
| db_model = WorkflowNodeExecutionModel() | |||
| db_model.id = domain_model.id | |||
| db_model.tenant_id = self._tenant_id | |||
| @@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| db_model.node_id = domain_model.node_id | |||
| db_model.node_type = domain_model.node_type | |||
| db_model.title = domain_model.title | |||
| db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None | |||
| db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None | |||
| db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None | |||
| db_model.inputs = ( | |||
| json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None | |||
| ) | |||
| db_model.process_data = ( | |||
| json.dumps(json_converter.to_json_encodable(domain_model.process_data)) | |||
| if domain_model.process_data | |||
| else None | |||
| ) | |||
| db_model.outputs = ( | |||
| json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None | |||
| ) | |||
| db_model.status = domain_model.status | |||
| db_model.error = domain_model.error | |||
| db_model.elapsed_time = domain_model.elapsed_time | |||
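| # Why the converter is applied before json.dumps in both repositories: workflow inputs and | |||
| # outputs may contain rich runtime values (File objects, Segment instances) that json.dumps | |||
| # cannot serialize on its own. Minimal sketch of the failure mode, with object() standing in | |||
| # for such a value: | |||
| import json | |||
| json.dumps({"status": "succeeded"})  # plain values serialize fine | |||
| # json.dumps({"files": [object()]})  # would raise TypeError without the conversion step | |||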
| @@ -75,6 +75,20 @@ class StringSegment(Segment): | |||
| class FloatSegment(Segment): | |||
| value_type: SegmentType = SegmentType.NUMBER | |||
| value: float | |||
| # NOTE(QuantumGhost): it seems that equality comparison for FloatSegment with a `NaN` value has some problems. | |||
| # The following tests cannot pass. | |||
| # | |||
| # def test_float_segment_and_nan(): | |||
| # nan = float("nan") | |||
| # assert nan != nan | |||
| # | |||
| # f1 = FloatSegment(value=float("nan")) | |||
| # f2 = FloatSegment(value=float("nan")) | |||
| # assert f1 != f2 | |||
| # | |||
| # f3 = FloatSegment(value=nan) | |||
| # f4 = FloatSegment(value=nan) | |||
| # assert f3 != f4 | |||
| class IntegerSegment(Segment): | |||
| @@ -18,3 +18,17 @@ class SegmentType(StrEnum): | |||
| NONE = "none" | |||
| GROUP = "group" | |||
| def is_array_type(self): | |||
| return self in _ARRAY_TYPES | |||
| _ARRAY_TYPES = frozenset( | |||
| [ | |||
| SegmentType.ARRAY_ANY, | |||
| SegmentType.ARRAY_STRING, | |||
| SegmentType.ARRAY_NUMBER, | |||
| SegmentType.ARRAY_OBJECT, | |||
| SegmentType.ARRAY_FILE, | |||
| ] | |||
| ) | |||
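| # Minimal usage of the new helper: the check is simple membership in _ARRAY_TYPES. | |||
| assert SegmentType.ARRAY_STRING.is_array_type() | |||
| assert not SegmentType.NONE.is_array_type() | |||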
| @@ -1,8 +1,26 @@ | |||
| import json | |||
| from collections.abc import Iterable, Sequence | |||
| from .segment_group import SegmentGroup | |||
| from .segments import ArrayFileSegment, FileSegment, Segment | |||
| def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: | |||
| selectors = [node_id, name] | |||
| if paths: | |||
| selectors.extend(paths) | |||
| return selectors | |||
| class SegmentJSONEncoder(json.JSONEncoder): | |||
| def default(self, o): | |||
| if isinstance(o, ArrayFileSegment): | |||
| return [v.model_dump() for v in o.value] | |||
| elif isinstance(o, FileSegment): | |||
| return o.value.model_dump() | |||
| elif isinstance(o, SegmentGroup): | |||
| return [self.default(seg) for seg in o.value] | |||
| elif isinstance(o, Segment): | |||
| return o.value | |||
| else: | |||
| return super().default(o) | |||
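| # to_selector simply prefixes the variable name with its node id and appends any nested | |||
| # paths; a couple of illustrative calls (the values are made up): | |||
| assert to_selector("1747829667553", "result") == ["1747829667553", "result"] | |||
| assert to_selector("node", "obj", paths=("a", "b")) == ["node", "obj", "a", "b"] | |||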
| @@ -0,0 +1,39 @@ | |||
| import abc | |||
| from typing import Protocol | |||
| from core.variables import Variable | |||
| class ConversationVariableUpdater(Protocol): | |||
| """ | |||
| ConversationVariableUpdater defines an abstraction for updating conversation variable values. | |||
| It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating | |||
| conversation variables. | |||
| Implementations may choose to batch updates. If batching is used, the `flush` method | |||
| should be implemented to persist buffered changes, and `update` | |||
| should handle buffering accordingly. | |||
| Note: Since implementations may buffer updates, instances of ConversationVariableUpdater | |||
| are not thread-safe. Each VariableAssignerNode should create its own instance during execution. | |||
| """ | |||
| @abc.abstractmethod | |||
| def update(self, conversation_id: str, variable: "Variable") -> None: | |||
| """ | |||
| Updates the value of the specified conversation variable in the underlying storage. | |||
| :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. | |||
| :param variable: The `Variable` instance containing the updated value. | |||
| """ | |||
| pass | |||
| @abc.abstractmethod | |||
| def flush(self): | |||
| """ | |||
| Flushes all pending updates to the underlying storage system. | |||
| If the implementation does not buffer updates, this method can be a no-op. | |||
| """ | |||
| pass | |||
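| # A minimal, non-buffering sketch of the protocol above (illustrative only; a real | |||
| # implementation would persist to the database and may buffer). update writes through | |||
| # immediately, so flush is a no-op. Assumes Variable exposes an `id` field. | |||
| class InMemoryConversationVariableUpdater: | |||
|     def __init__(self) -> None: | |||
|         self._store: dict[tuple[str, str], Variable] = {} | |||
|     def update(self, conversation_id: str, variable: Variable) -> None: | |||
|         self._store[(conversation_id, variable.id)] = variable | |||
|     def flush(self) -> None: | |||
|         pass  # nothing is buffered, so there is nothing to flush | |||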
| @@ -7,12 +7,12 @@ from pydantic import BaseModel, Field | |||
| from core.file import File, FileAttribute, file_manager | |||
| from core.variables import Segment, SegmentGroup, Variable | |||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||
| from core.variables.segments import FileSegment, NoneSegment | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.enums import SystemVariableKey | |||
| from factories import variable_factory | |||
| from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from ..enums import SystemVariableKey | |||
| VariableValue = Union[str, int, float, dict, list, File] | |||
| VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") | |||
| @@ -30,9 +30,11 @@ class VariablePool(BaseModel): | |||
| # TODO: This user inputs is not used for pool. | |||
| user_inputs: Mapping[str, Any] = Field( | |||
| description="User inputs", | |||
| default_factory=dict, | |||
| ) | |||
| system_variables: Mapping[SystemVariableKey, Any] = Field( | |||
| description="System variables", | |||
| default_factory=dict, | |||
| ) | |||
| environment_variables: Sequence[Variable] = Field( | |||
| description="Environment variables.", | |||
| @@ -43,28 +45,7 @@ class VariablePool(BaseModel): | |||
| default_factory=list, | |||
| ) | |||
| def __init__( | |||
| self, | |||
| *, | |||
| system_variables: Mapping[SystemVariableKey, Any] | None = None, | |||
| user_inputs: Mapping[str, Any] | None = None, | |||
| environment_variables: Sequence[Variable] | None = None, | |||
| conversation_variables: Sequence[Variable] | None = None, | |||
| **kwargs, | |||
| ): | |||
| environment_variables = environment_variables or [] | |||
| conversation_variables = conversation_variables or [] | |||
| user_inputs = user_inputs or {} | |||
| system_variables = system_variables or {} | |||
| super().__init__( | |||
| system_variables=system_variables, | |||
| user_inputs=user_inputs, | |||
| environment_variables=environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| **kwargs, | |||
| ) | |||
| def model_post_init(self, context: Any, /) -> None: | |||
| for key, value in self.system_variables.items(): | |||
| self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) | |||
| # Add environment variables to the variable pool | |||
| @@ -91,12 +72,12 @@ class VariablePool(BaseModel): | |||
| Returns: | |||
| None | |||
| """ | |||
| if len(selector) < 2: | |||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||
| raise ValueError("Invalid selector") | |||
| if isinstance(value, Variable): | |||
| variable = value | |||
| if isinstance(value, Segment): | |||
| elif isinstance(value, Segment): | |||
| variable = variable_factory.segment_to_variable(segment=value, selector=selector) | |||
| else: | |||
| segment = variable_factory.build_segment(value) | |||
| @@ -118,7 +99,7 @@ class VariablePool(BaseModel): | |||
| Raises: | |||
| ValueError: If the selector is invalid. | |||
| """ | |||
| if len(selector) < 2: | |||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||
| return None | |||
| hash_key = hash(tuple(selector[1:])) | |||
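| # The selector contract enforced above, as a sketch (assumes MIN_SELECTORS_LENGTH is 2, | |||
| # matching the literal it replaces): a selector is at least (node_id, variable_name), so a | |||
| # bare node id resolves to nothing. | |||
| pool = VariablePool(system_variables={}, user_inputs={}) | |||
| pool.add(("start", "query"), "hello") | |||
| assert pool.get(("start",)) is None  # shorter than MIN_SELECTORS_LENGTH | |||
| assert pool.get(("start", "query")) is not None | |||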
| @@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent): | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| # The version of the node, or "1" if not specified. | |||
| node_version: str = "1" | |||
| class NodeRunStartedEvent(BaseNodeEvent): | |||
| @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | |||
| from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.utils import variable_utils | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from models.enums import UserFrom | |||
| from models.workflow import WorkflowType | |||
| @@ -314,6 +315,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| raise e | |||
| @@ -627,6 +629,7 @@ class GraphEngine: | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| agent_strategy=agent_strategy, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| max_retries = node_instance.node_data.retry_config.max_retries | |||
| @@ -677,6 +680,7 @@ class GraphEngine: | |||
| error=run_result.error or "Unknown error", | |||
| retry_index=retries, | |||
| start_at=retry_start_at, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| time.sleep(retry_interval) | |||
| break | |||
| @@ -712,6 +716,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| should_continue_retry = False | |||
| else: | |||
| @@ -726,6 +731,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| should_continue_retry = False | |||
| elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| @@ -786,6 +792,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| should_continue_retry = False | |||
| @@ -803,6 +810,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| elif isinstance(event, RunRetrieverResourceEvent): | |||
| yield NodeRunRetrieverResourceEvent( | |||
| @@ -817,6 +825,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| except GenerateTaskStoppedError: | |||
| # trigger node run failed event | |||
| @@ -833,6 +842,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| node_version=node_instance.version(), | |||
| ) | |||
| return | |||
| except Exception as e: | |||
| @@ -847,16 +857,12 @@ class GraphEngine: | |||
| :param variable_value: variable value | |||
| :return: | |||
| """ | |||
| self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) | |||
| # if variable_value is a dict, then recursively append variables | |||
| if isinstance(variable_value, dict): | |||
| for key, value in variable_value.items(): | |||
| # construct new key list | |||
| new_key_list = variable_key_list + [key] | |||
| self._append_variables_recursively( | |||
| node_id=node_id, variable_key_list=new_key_list, variable_value=value | |||
| ) | |||
| variable_utils.append_variables_recursively( | |||
| self.graph_runtime_state.variable_pool, | |||
| node_id, | |||
| variable_key_list, | |||
| variable_value, | |||
| ) | |||
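| # Sketch of what the extracted helper presumably does, mirroring the inlined logic it | |||
| # replaces (the actual implementation lives in core.workflow.utils.variable_utils and its | |||
| # signature may differ): add the value under node_id + key list, then recurse into dict | |||
| # values so nested members get their own selectors. | |||
| def append_variables_recursively(pool, node_id, variable_key_list, variable_value): | |||
|     pool.add([node_id] + list(variable_key_list), variable_value) | |||
|     if isinstance(variable_value, dict): | |||
|         for key, value in variable_value.items(): | |||
|             append_variables_recursively(pool, node_id, list(variable_key_list) + [key], value) | |||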
| def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: | |||
| """ | |||
| @@ -39,6 +39,10 @@ class AgentNode(ToolNode): | |||
| _node_data_cls = AgentNodeData # type: ignore | |||
| _node_type = NodeType.AGENT | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator: | |||
| """ | |||
| Run the agent node | |||
| @@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| class AnswerNode(BaseNode[AnswerNodeData]): | |||
| _node_data_cls = AnswerNodeData | |||
| _node_type: NodeType = NodeType.ANSWER | |||
| _node_type = NodeType.ANSWER | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| @@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]): | |||
| part = cast(TextGenerateRouteChunk, part) | |||
| answer += part.text | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"answer": answer, "files": ArrayFileSegment(value=files)}, | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| @@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor): | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| from_variable_selector=[answer_node_id, "answer"], | |||
| node_version=event.node_version, | |||
| ) | |||
| else: | |||
| route_chunk = cast(VarGenerateRouteChunk, route_chunk) | |||
| @@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor): | |||
| route_node_state=event.route_node_state, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| node_version=event.node_version, | |||
| ) | |||
| self.route_position[answer_node_id] += 1 | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| from abc import abstractmethod | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast | |||
| from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| @@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) | |||
| class BaseNode(Generic[GenericNodeData]): | |||
| _node_data_cls: type[GenericNodeData] | |||
| _node_type: NodeType | |||
| _node_type: ClassVar[NodeType] | |||
| def __init__( | |||
| self, | |||
| @@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]): | |||
| graph_config: Mapping[str, Any], | |||
| config: Mapping[str, Any], | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| """Extracts references variable selectors from node configuration. | |||
| The `config` parameter represents the configuration for a specific node type and corresponds | |||
| to the `data` field in the node definition object. | |||
| The returned mapping has the following structure: | |||
| {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} | |||
| For loop and iteration nodes, the mapping may look like this: | |||
| { | |||
| "1748332301644.input_selector": ["1748332363630", "result"], | |||
| "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], | |||
| } | |||
| where `1748332301644` is the ID of the loop / iteration node, | |||
| and `1748332325079` is the ID of the node inside the loop or iteration node. | |||
| Here, the key consists of two parts: the current node ID (provided as the `node_id` | |||
| parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, | |||
| enclosed in `#` symbols. These two parts are separated by a dot (`.`). | |||
| The value is a list of strings representing the variable selector, where the first element is the node ID | |||
| of the referenced variable, and the second element is the variable name within that node. | |||
| The meaning of the above mapping is: | |||
| The node with ID `1747829548239` references the variable `result` from the node with | |||
| ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a | |||
| reference to the `result` output variable of node `1747829667553`. | |||
| :param graph_config: graph config | |||
| :param config: node config | |||
| :return: | |||
| @@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]): | |||
| raise ValueError("Node ID is required when extracting variable selector to variable mapping.") | |||
| node_data = cls._node_data_cls(**config.get("data", {})) | |||
| return cls._extract_variable_selector_to_variable_mapping( | |||
| data = cls._extract_variable_selector_to_variable_mapping( | |||
| graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) | |||
| ) | |||
| return data | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| @@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]): | |||
| """ | |||
| return self._node_type | |||
| @classmethod | |||
| @abstractmethod | |||
| def version(cls) -> str: | |||
| """`node_version` returns the version of current node type.""" | |||
| # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. | |||
| # | |||
| # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` | |||
| # in `api/core/workflow/nodes/__init__.py`. | |||
| raise NotImplementedError("subclasses of BaseNode must implement `version` method.") | |||
| @property | |||
| def should_continue_on_error(self) -> bool: | |||
| """judge if should continue on error | |||
| @@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]): | |||
| return code_provider.get_default_config() | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| # Get code language | |||
| code_language = self.node_data.code_language | |||
| @@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]): | |||
| prefix: str = "", | |||
| depth: int = 1, | |||
| ): | |||
| # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. | |||
| # Note that `_transform_result` may produce lists containing `None` values, | |||
| # which don't conform to the type requirements of `Array*Segment` classes. | |||
| if depth > dify_config.CODE_MAX_DEPTH: | |||
| raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") | |||
| @@ -24,7 +24,7 @@ from configs import dify_config | |||
| from core.file import File, FileTransferMethod, file_manager | |||
| from core.helper import ssrf_proxy | |||
| from core.variables import ArrayFileSegment | |||
| from core.variables.segments import FileSegment | |||
| from core.variables.segments import ArrayStringSegment, FileSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): | |||
| _node_data_cls = DocumentExtractorNodeData | |||
| _node_type = NodeType.DOCUMENT_EXTRACTOR | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self): | |||
| variable_selector = self.node_data.variable_selector | |||
| variable = self.graph_runtime_state.variable_pool.get(variable_selector) | |||
| @@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=inputs, | |||
| process_data=process_data, | |||
| outputs={"text": extracted_text_list}, | |||
| outputs={"text": ArrayStringSegment(value=extracted_text_list)}, | |||
| ) | |||
| elif isinstance(value, File): | |||
| extracted_text = _extract_text_from_file(value) | |||
| @@ -447,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str: | |||
| df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore | |||
| # Combine multi-line text in column names into a single line | |||
| df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) | |||
| df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) | |||
| # Manually construct the Markdown table | |||
| markdown_table += _construct_markdown_table(df) + "\n\n" | |||
| @@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]): | |||
| _node_data_cls = EndNodeData | |||
| _node_type = NodeType.END | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| @@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor): | |||
| route_node_state=event.route_node_state, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| node_version=event.node_version, | |||
| ) | |||
| self.route_position[end_node_id] += 1 | |||
| @@ -6,6 +6,7 @@ from typing import Any, Optional | |||
| from configs import dify_config | |||
| from core.file import File, FileTransferMethod | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.variables.segments import ArrayFileSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| @@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| }, | |||
| } | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| process_data = {} | |||
| try: | |||
| @@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={ | |||
| "status_code": response.status_code, | |||
| "body": response.text if not files else "", | |||
| "body": response.text if not files.value else "", | |||
| "headers": response.headers, | |||
| "files": files, | |||
| }, | |||
| @@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| return mapping | |||
| def extract_files(self, url: str, response: Response) -> list[File]: | |||
| def extract_files(self, url: str, response: Response) -> ArrayFileSegment: | |||
| """ | |||
| Extract files from response by checking both Content-Type header and URL | |||
| """ | |||
| @@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| content_disposition_type = None | |||
| if not is_file: | |||
| return files | |||
| return ArrayFileSegment(value=[]) | |||
| if parsed_content_disposition: | |||
| content_disposition_filename = parsed_content_disposition.get_filename() | |||
| @@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| ) | |||
| files.append(file) | |||
| return files | |||
| return ArrayFileSegment(value=files) | |||
| @@ -1,4 +1,5 @@ | |||
| from typing import Literal | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Literal | |||
| from typing_extensions import deprecated | |||
| @@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]): | |||
| _node_data_cls = IfElseNodeData | |||
| _node_type = NodeType.IF_ELSE | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run node | |||
| @@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]): | |||
| return data | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: IfElseNodeData, | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| var_mapping: dict[str, list[str]] = {} | |||
| for case in node_data.cases or []: | |||
| for condition in case.conditions: | |||
| key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) | |||
| var_mapping[key] = condition.variable_selector | |||
| return var_mapping | |||
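| # The keys produced above follow the convention documented in BaseNode: | |||
| # "<node_id>.#<selector joined with dots>#". For a condition on selector | |||
| # ["1747829667553", "result"] inside node "1747829548239": | |||
| key = "{}.#{}#".format("1747829548239", ".".join(["1747829667553", "result"])) | |||
| assert key == "1747829548239.#1747829667553.result#" | |||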
| @deprecated("This function is deprecated. You should use the new cases structure.") | |||
| def _should_not_use_old_function( | |||
| @@ -11,6 +11,7 @@ from flask import Flask, current_app | |||
| from configs import dify_config | |||
| from core.variables import ArrayVariable, IntegerVariable, NoneVariable | |||
| from core.variables.segments import ArrayAnySegment, ArraySegment | |||
| from core.workflow.entities.node_entities import ( | |||
| NodeRunResult, | |||
| ) | |||
| @@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | |||
| from factories.variable_factory import build_segment | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from .exc import ( | |||
| @@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| }, | |||
| } | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run the node. | |||
| @@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") | |||
| if isinstance(variable, NoneVariable) or len(variable.value) == 0: | |||
| # Try our best to preserve the type information. | |||
| if isinstance(variable, ArraySegment): | |||
| output = variable.model_copy(update={"value": []}) | |||
| else: | |||
| output = ArrayAnySegment(value=[]) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"output": []}, | |||
| # TODO(QuantumGhost): is it possible to compute the type of `output` | |||
| # from graph definition? | |||
| outputs={"output": output}, | |||
| ) | |||
| ) | |||
| return | |||
| @@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| # Flatten the list of lists | |||
| if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): | |||
| outputs = [item for sublist in outputs for item in sublist] | |||
| output_segment = build_segment(outputs) | |||
| yield IterationRunSucceededEvent( | |||
| iteration_id=self.id, | |||
| @@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"output": outputs}, | |||
| outputs={"output": output_segment}, | |||
| metadata={ | |||
| WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, | |||
| @@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]): | |||
| _node_data_cls = IterationStartNodeData | |||
| _node_type = NodeType.ITERATION_START | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the node. | |||
| @@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables import StringSegment | |||
| from core.variables.segments import ArrayObjectSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.enums import NodeType | |||
| @@ -70,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| _node_data_cls = KnowledgeRetrievalNodeData # type: ignore | |||
| _node_type = NodeType.KNOWLEDGE_RETRIEVAL | |||
| @classmethod | |||
| def version(cls): | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: # type: ignore | |||
| node_data = cast(KnowledgeRetrievalNodeData, self.node_data) | |||
| # extract variables | |||
| @@ -115,9 +120,12 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| # retrieve knowledge | |||
| try: | |||
| results = self._fetch_dataset_retriever(node_data=node_data, query=query) | |||
| outputs = {"result": results} | |||
| outputs = {"result": ArrayObjectSegment(value=results)} | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=variables, | |||
| process_data=None, | |||
| outputs=outputs, # type: ignore | |||
| ) | |||
| except KnowledgeRetrievalNodeError as e: | |||
| @@ -3,6 +3,7 @@ from typing import Any, Literal, Union | |||
| from core.file import File | |||
| from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment | |||
| from core.variables.segments import ArrayAnySegment, ArraySegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): | |||
| _node_data_cls = ListOperatorNodeData | |||
| _node_type = NodeType.LIST_OPERATOR | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self): | |||
| inputs: dict[str, list] = {} | |||
| process_data: dict[str, list] = {} | |||
| @@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): | |||
| if not variable.value: | |||
| inputs = {"variable": []} | |||
| process_data = {"variable": []} | |||
| outputs = {"result": [], "first_record": None, "last_record": None} | |||
| if isinstance(variable, ArraySegment): | |||
| result = variable.model_copy(update={"value": []}) | |||
| else: | |||
| result = ArrayAnySegment(value=[]) | |||
| outputs = {"result": result, "first_record": None, "last_record": None} | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=inputs, | |||
| @@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): | |||
| variable = self._apply_slice(variable) | |||
| outputs = { | |||
| "result": variable.value, | |||
| "result": variable, | |||
| "first_record": variable.value[0] if variable.value else None, | |||
| "last_record": variable.value[-1] if variable.value else None, | |||
| } | |||
| @@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver): | |||
| size=len(data), | |||
| related_id=tool_file.id, | |||
| url=url, | |||
| # TODO(QuantumGhost): how should I set the following key? | |||
| # What's the difference between `remote_url` and `url`? | |||
| # What's the purpose of `storage_key` and `dify_model_identity`? | |||
| storage_key=tool_file.file_key, | |||
| ) | |||
| @@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| ) | |||
| self._llm_file_saver = llm_file_saver | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| def process_structured_output(text: str) -> Optional[dict[str, Any]]: | |||
| """Process structured output if enabled""" | |||
| @@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| if structured_output: | |||
| outputs["structured_output"] = structured_output | |||
| if self._file_outputs is not None: | |||
| outputs["files"] = self._file_outputs | |||
| outputs["files"] = ArrayFileSegment(value=self._file_outputs) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| @@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]): | |||
| _node_data_cls = LoopEndNodeData | |||
| _node_type = NodeType.LOOP_END | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the node. | |||
| @@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| _node_data_cls = LoopNodeData | |||
| _node_type = NodeType.LOOP | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| """Run the node.""" | |||
| # Get inputs | |||
| @@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| variable_mapping.update(sub_node_variable_mapping) | |||
| for loop_variable in node_data.loop_variables or []: | |||
| if loop_variable.value_type == "variable": | |||
| assert loop_variable.value is not None, "Loop variable value must be provided for variable type" | |||
| # add loop variable to variable mapping | |||
| selector = loop_variable.value | |||
| variable_mapping[f"{node_id}.{loop_variable.label}"] = selector | |||
| # remove variable out from loop | |||
| variable_mapping = { | |||
| key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids | |||
| @@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]): | |||
| _node_data_cls = LoopStartNodeData | |||
| _node_type = NodeType.LOOP_START | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| """ | |||
| Run the node. | |||
| @@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var | |||
| LATEST_VERSION = "latest" | |||
| # NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. | |||
| # Specifically, if you have introduced new node types, you should add them here. | |||
| # | |||
| # TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` | |||
| # hook. Try to avoid duplication of node information. | |||
| NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { | |||
| NodeType.START: { | |||
| LATEST_VERSION: StartNode, | |||
| @@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData | |||
| from core.workflow.nodes.llm import ModelConfig, VisionConfig | |||
| class _ParameterConfigError(Exception): | |||
| pass | |||
| class ParameterConfig(BaseModel): | |||
| """ | |||
| Parameter Config. | |||
| @@ -27,6 +31,19 @@ class ParameterConfig(BaseModel): | |||
| raise ValueError("Invalid parameter name, __reason and __is_success are reserved") | |||
| return str(value) | |||
| def is_array_type(self) -> bool: | |||
| return self.type in ("array[string]", "array[number]", "array[object]") | |||
| def element_type(self) -> Literal["string", "number", "object"]: | |||
| if self.type == "array[number]": | |||
| return "number" | |||
| elif self.type == "array[string]": | |||
| return "string" | |||
| elif self.type == "array[object]": | |||
| return "object" | |||
| else: | |||
| raise _ParameterConfigError(f"{self.type} is not array type.") | |||
| class ParameterExtractorNodeData(BaseNodeData): | |||
| """ | |||
| @@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| @@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.llm import ModelConfig, llm_utils | |||
| from core.workflow.utils import variable_template_parser | |||
| from factories.variable_factory import build_segment_with_type | |||
| from .entities import ParameterExtractorNodeData | |||
| from .exc import ( | |||
| @@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode): | |||
| } | |||
| } | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self): | |||
| """ | |||
| Run the node. | |||
| @@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode): | |||
| elif parameter.type in {"string", "select"}: | |||
| if isinstance(result[parameter.name], str): | |||
| transformed_result[parameter.name] = result[parameter.name] | |||
| elif parameter.type.startswith("array"): | |||
| elif parameter.is_array_type(): | |||
| if isinstance(result[parameter.name], list): | |||
| nested_type = parameter.type[6:-1] | |||
| transformed_result[parameter.name] = [] | |||
| nested_type = parameter.element_type() | |||
| assert nested_type is not None | |||
| segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) | |||
| transformed_result[parameter.name] = segment_value | |||
| for item in result[parameter.name]: | |||
| if nested_type == "number": | |||
| if isinstance(item, int | float): | |||
| transformed_result[parameter.name].append(item) | |||
| segment_value.value.append(item) | |||
| elif isinstance(item, str): | |||
| try: | |||
| if "." in item: | |||
| transformed_result[parameter.name].append(float(item)) | |||
| segment_value.value.append(float(item)) | |||
| else: | |||
| transformed_result[parameter.name].append(int(item)) | |||
| segment_value.value.append(int(item)) | |||
| except ValueError: | |||
| pass | |||
| elif nested_type == "string": | |||
| if isinstance(item, str): | |||
| transformed_result[parameter.name].append(item) | |||
| segment_value.value.append(item) | |||
| elif nested_type == "object": | |||
| if isinstance(item, dict): | |||
| transformed_result[parameter.name].append(item) | |||
| segment_value.value.append(item) | |||
| if parameter.name not in transformed_result: | |||
| if parameter.type == "number": | |||
| @@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode): | |||
| elif parameter.type in {"string", "select"}: | |||
| transformed_result[parameter.name] = "" | |||
| elif parameter.type.startswith("array"): | |||
| transformed_result[parameter.name] = [] | |||
| transformed_result[parameter.name] = build_segment_with_type( | |||
| segment_type=SegmentType(parameter.type), value=[] | |||
| ) | |||
| return transformed_result | |||
| @@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode): | |||
| _node_data_cls = QuestionClassifierNodeData # type: ignore | |||
| _node_type = NodeType.QUESTION_CLASSIFIER | |||
| @classmethod | |||
| def version(cls): | |||
| return "1" | |||
| def _run(self): | |||
| node_data = cast(QuestionClassifierNodeData, self.node_data) | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| @@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]): | |||
| _node_data_cls = StartNodeData | |||
| _node_type = NodeType.START | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) | |||
| system_inputs = self.graph_runtime_state.variable_pool.system_variables | |||
| @@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]): | |||
| # Set system variables as node outputs. | |||
| for var in system_inputs: | |||
| node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] | |||
| outputs = dict(node_inputs) | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) | |||
| @@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): | |||
| "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, | |||
| } | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| # Get variables | |||
| variables = {} | |||
| @@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.tools.tool_engine import ToolEngine | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from core.variables.segments import ArrayAnySegment | |||
| from core.variables.segments import ArrayAnySegment, ArrayFileSegment | |||
| from core.variables.variables import ArrayAnyVariable | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| _node_data_cls = ToolNodeData | |||
| _node_type = NodeType.TOOL | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> Generator: | |||
| """ | |||
| Run the tool node | |||
| @@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| variables[variable_name] = variable_value | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE: | |||
| assert message.meta is not None | |||
| assert isinstance(message.meta["file"], File) | |||
| files.append(message.meta["file"]) | |||
| elif message.type == ToolInvokeMessage.MessageType.LOG: | |||
| assert isinstance(message.message, ToolInvokeMessage.LogMessage) | |||
| @@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"text": text, "files": files, "json": json, **variables}, | |||
| outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, | |||
| metadata={ | |||
| **agent_execution_metadata, | |||
| WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, | |||
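With this change the tool node's `files` output is a typed `ArrayFileSegment` instead of a bare list, so downstream code can rely on the segment's `value_type` rather than re-inferring it. A hedged sketch of the shape (the empty list stands in for real `File` objects, which are omitted here):

```python
from core.variables.segments import ArrayFileSegment

files_segment = ArrayFileSegment(value=[])  # would normally hold File instances
outputs = {"text": "", "files": files_segment, "json": []}

# Consumers unwrap the underlying list through .value instead of receiving a raw list.
assert outputs["files"].value == []
```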
| @@ -1,3 +1,6 @@ | |||
| from collections.abc import Mapping | |||
| from core.variables.segments import Segment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): | |||
| _node_data_cls = VariableAssignerNodeData | |||
| _node_type = NodeType.VARIABLE_AGGREGATOR | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| def _run(self) -> NodeRunResult: | |||
| # Get variables | |||
| outputs = {} | |||
| outputs: dict[str, Segment | Mapping[str, Segment]] = {} | |||
| inputs = {} | |||
| if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: | |||
| for selector in self.node_data.variables: | |||
| variable = self.graph_runtime_state.variable_pool.get(selector) | |||
| if variable is not None: | |||
| outputs = {"output": variable.to_object()} | |||
| outputs = {"output": variable} | |||
| inputs = {".".join(selector[1:]): variable.to_object()} | |||
| break | |||
| @@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): | |||
| variable = self.graph_runtime_state.variable_pool.get(selector) | |||
| if variable is not None: | |||
| outputs[group.group_name] = {"output": variable.to_object()} | |||
| outputs[group.group_name] = {"output": variable} | |||
| inputs[".".join(selector[1:])] = variable.to_object() | |||
| break | |||
| @@ -1,19 +1,55 @@ | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from collections.abc import Mapping, MutableMapping, Sequence | |||
| from typing import Any, TypeVar | |||
| from core.variables import Variable | |||
| from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError | |||
| from extensions.ext_database import db | |||
| from models import ConversationVariable | |||
| from pydantic import BaseModel | |||
| from core.variables import Segment | |||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||
| from core.variables.types import SegmentType | |||
| def update_conversation_variable(conversation_id: str, variable: Variable): | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id | |||
| # Use double underscore (`__`) prefix for internal variables | |||
| # to minimize risk of collision with user-defined variable names. | |||
| _UPDATED_VARIABLES_KEY = "__updated_variables" | |||
| class UpdatedVariable(BaseModel): | |||
| name: str | |||
| selector: Sequence[str] | |||
| value_type: SegmentType | |||
| new_value: Any | |||
| _T = TypeVar("_T", bound=MutableMapping[str, Any]) | |||
| def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: | |||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||
| raise Exception("selector too short") | |||
| node_id, var_name = selector[:2] | |||
| return UpdatedVariable( | |||
| name=var_name, | |||
| selector=list(selector[:2]), | |||
| value_type=seg.value_type, | |||
| new_value=seg.value, | |||
| ) | |||
| with Session(db.engine) as session: | |||
| row = session.scalar(stmt) | |||
| if not row: | |||
| raise VariableOperatorNodeError("conversation variable not found in the database") | |||
| row.data = variable.model_dump_json() | |||
| session.commit() | |||
| def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: | |||
| m[_UPDATED_VARIABLES_KEY] = updates | |||
| return m | |||
| def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: | |||
| updated_values = m.get(_UPDATED_VARIABLES_KEY, None) | |||
| if updated_values is None: | |||
| return None | |||
| result = [] | |||
| for items in updated_values: | |||
| if isinstance(items, UpdatedVariable): | |||
| result.append(items) | |||
| elif isinstance(items, dict): | |||
| items = UpdatedVariable.model_validate(items) | |||
| result.append(items) | |||
| else: | |||
| raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") | |||
| return result | |||
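`set_updated_variables` stashes the records under the internal `__updated_variables` key, and `get_updated_variables` accepts either `UpdatedVariable` instances or their serialized dict form. A small round-trip sketch reusing the names defined above (the selector and values are made up):

```python
record = UpdatedVariable(
    name="memo",
    selector=["conversation", "memo"],
    value_type=SegmentType("string"),
    new_value="hello",
)

process_data = set_updated_variables({}, [record])
assert get_updated_variables(process_data) == [record]

# The getter re-validates plain dicts too, e.g. after a JSON round trip.
serialized = {"__updated_variables": [record.model_dump(mode="json")]}
assert get_updated_variables(serialized) == [record]
```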
| @@ -0,0 +1,38 @@ | |||
| from sqlalchemy import Engine, select | |||
| from sqlalchemy.orm import Session | |||
| from core.variables.variables import Variable | |||
| from models.engine import db | |||
| from models.workflow import ConversationVariable | |||
| from .exc import VariableOperatorNodeError | |||
| class ConversationVariableUpdaterImpl: | |||
| _engine: Engine | None | |||
| def __init__(self, engine: Engine | None = None) -> None: | |||
| self._engine = engine | |||
| def _get_engine(self) -> Engine: | |||
| if self._engine: | |||
| return self._engine | |||
| return db.engine | |||
| def update(self, conversation_id: str, variable: Variable): | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id | |||
| ) | |||
| with Session(self._get_engine()) as session: | |||
| row = session.scalar(stmt) | |||
| if not row: | |||
| raise VariableOperatorNodeError("conversation variable not found in the database") | |||
| row.data = variable.model_dump_json() | |||
| session.commit() | |||
| def flush(self): | |||
| pass | |||
| def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: | |||
| return ConversationVariableUpdaterImpl() | |||
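The optional `engine` argument is the seam that makes this updater testable: the default factory falls back to `db.engine`, while tests can inject a dedicated `Engine`. A hedged sketch (the SQLite URL is illustrative, and the required tables would still need to be created on that engine):

```python
from sqlalchemy import create_engine

# Production path: module-level factory backed by db.engine.
updater = conversation_variable_updater_factory()

# Test path: inject an isolated engine (schema setup omitted).
test_engine = create_engine("sqlite:///:memory:")
test_updater = ConversationVariableUpdaterImpl(engine=test_engine)

# Both expose the same two-step API the assigner nodes rely on:
# updater.update(conversation_id=..., variable=...) followed by updater.flush().
```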
| @@ -1,4 +1,9 @@ | |||
| from collections.abc import Callable, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional, TypeAlias | |||
| from core.variables import SegmentType, Variable | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID | |||
| from core.workflow.conversation_variable_updater import ConversationVariableUpdater | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe | |||
| from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError | |||
| from factories import variable_factory | |||
| from ..common.impl import conversation_variable_updater_factory | |||
| from .node_data import VariableAssignerData, WriteMode | |||
| if TYPE_CHECKING: | |||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||
| _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] | |||
| class VariableAssignerNode(BaseNode[VariableAssignerData]): | |||
| _node_data_cls = VariableAssignerData | |||
| _node_type = NodeType.VARIABLE_ASSIGNER | |||
| _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY | |||
| def __init__( | |||
| self, | |||
| id: str, | |||
| config: Mapping[str, Any], | |||
| graph_init_params: "GraphInitParams", | |||
| graph: "Graph", | |||
| graph_runtime_state: "GraphRuntimeState", | |||
| previous_node_id: Optional[str] = None, | |||
| thread_pool_id: Optional[str] = None, | |||
| conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, | |||
| ) -> None: | |||
| super().__init__( | |||
| id=id, | |||
| config=config, | |||
| graph_init_params=graph_init_params, | |||
| graph=graph, | |||
| graph_runtime_state=graph_runtime_state, | |||
| previous_node_id=previous_node_id, | |||
| thread_pool_id=thread_pool_id, | |||
| ) | |||
| self._conv_var_updater_factory = conv_var_updater_factory | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "1" | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: VariableAssignerData, | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| mapping = {} | |||
| assigned_variable_node_id = node_data.assigned_variable_selector[0] | |||
| if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: | |||
| selector_key = ".".join(node_data.assigned_variable_selector) | |||
| key = f"{node_id}.#{selector_key}#" | |||
| mapping[key] = node_data.assigned_variable_selector | |||
| selector_key = ".".join(node_data.input_variable_selector) | |||
| key = f"{node_id}.#{selector_key}#" | |||
| mapping[key] = node_data.input_variable_selector | |||
| return mapping | |||
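The mapping keys follow the `"{node_id}.#{selector joined with dots}#"` convention that `mapping_user_inputs_to_variable_pool` later splits apart. A self-contained sketch of the resulting shape for a hypothetical assigner node (assuming `CONVERSATION_VARIABLE_NODE_ID` is the literal `"conversation"`; the ids and selectors are made up):

```python
node_id = "1717382400000"                     # hypothetical node id
assigned_selector = ["conversation", "memo"]  # conversation variable being written
input_selector = ["1717382300000", "text"]    # source value selector

mapping: dict[str, list[str]] = {}
if assigned_selector[0] == "conversation":    # CONVERSATION_VARIABLE_NODE_ID
    mapping[f"{node_id}.#{'.'.join(assigned_selector)}#"] = assigned_selector
mapping[f"{node_id}.#{'.'.join(input_selector)}#"] = input_selector

# {'1717382400000.#conversation.memo#': [...], '1717382400000.#1717382300000.text#': [...]}
print(mapping)
```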
| def _run(self) -> NodeRunResult: | |||
| assigned_variable_selector = self.node_data.assigned_variable_selector | |||
| # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject | |||
| original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) | |||
| original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) | |||
| if not isinstance(original_variable, Variable): | |||
| raise VariableOperatorNodeError("assigned variable not found") | |||
| @@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): | |||
| raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") | |||
| # Over write the variable. | |||
| self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) | |||
| self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) | |||
| # TODO: Move database operation to the pipeline. | |||
| # Update conversation variable. | |||
| conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) | |||
| if not conversation_id: | |||
| raise VariableOperatorNodeError("conversation_id not found") | |||
| common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) | |||
| conv_var_updater = self._conv_var_updater_factory() | |||
| conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) | |||
| conv_var_updater.flush() | |||
| updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={ | |||
| "value": income_value.to_object(), | |||
| }, | |||
| # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, | |||
| # we still record `updated_variables` as a list so that the output schema stays | |||
| # compatible with `v2.VariableAssignerNode`. | |||
| process_data=common_helpers.set_updated_variables({}, updated_variables), | |||
| outputs={}, | |||
| ) | |||
| @@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel): | |||
| variable_selector: Sequence[str] | |||
| input_type: InputType | |||
| operation: Operation | |||
| # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: | |||
| # | |||
| # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. | |||
| # 2. For VARIABLE input_type: Initially contains the selector of the source variable. | |||
| # 3. During the variable updating procedure: The `value` field is reassigned to hold | |||
| # the resolved actual value that will be applied to the target variable. | |||
| value: Any | None = None | |||
| @@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError): | |||
| class ConversationIDNotFoundError(VariableOperatorNodeError): | |||
| def __init__(self): | |||
| super().__init__("conversation_id not found") | |||
| class InvalidDataError(VariableOperatorNodeError): | |||
| def __init__(self, message: str) -> None: | |||
| super().__init__(message) | |||
| @@ -1,34 +1,84 @@ | |||
| import json | |||
| from collections.abc import Sequence | |||
| from typing import Any, cast | |||
| from collections.abc import Callable, Mapping, MutableMapping, Sequence | |||
| from typing import Any, TypeAlias, cast | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.variables import SegmentType, Variable | |||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID | |||
| from core.workflow.conversation_variable_updater import ConversationVariableUpdater | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.variable_assigner.common import helpers as common_helpers | |||
| from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError | |||
| from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory | |||
| from . import helpers | |||
| from .constants import EMPTY_VALUE_MAPPING | |||
| from .entities import VariableAssignerNodeData | |||
| from .entities import VariableAssignerNodeData, VariableOperationItem | |||
| from .enums import InputType, Operation | |||
| from .exc import ( | |||
| ConversationIDNotFoundError, | |||
| InputTypeNotSupportedError, | |||
| InvalidDataError, | |||
| InvalidInputValueError, | |||
| OperationNotSupportedError, | |||
| VariableNotFoundError, | |||
| ) | |||
| _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] | |||
| def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): | |||
| selector_node_id = item.variable_selector[0] | |||
| if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: | |||
| return | |||
| selector_str = ".".join(item.variable_selector) | |||
| key = f"{node_id}.#{selector_str}#" | |||
| mapping[key] = item.variable_selector | |||
| def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): | |||
| # Keep this in sync with the logic in _run methods... | |||
| if item.input_type != InputType.VARIABLE: | |||
| return | |||
| selector = item.value | |||
| if not isinstance(selector, list): | |||
| raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") | |||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||
| raise InvalidDataError(f"selector too short, {node_id=}, {item=}") | |||
| selector_str = ".".join(selector) | |||
| key = f"{node_id}.#{selector_str}#" | |||
| mapping[key] = selector | |||
| class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): | |||
| _node_data_cls = VariableAssignerNodeData | |||
| _node_type = NodeType.VARIABLE_ASSIGNER | |||
| def _conv_var_updater_factory(self) -> ConversationVariableUpdater: | |||
| return conversation_variable_updater_factory() | |||
| @classmethod | |||
| def version(cls) -> str: | |||
| return "2" | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: VariableAssignerNodeData, | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| var_mapping: dict[str, Sequence[str]] = {} | |||
| for item in node_data.items: | |||
| _target_mapping_from_item(var_mapping, node_id, item) | |||
| _source_mapping_from_item(var_mapping, node_id, item) | |||
| return var_mapping | |||
| def _run(self) -> NodeRunResult: | |||
| inputs = self.node_data.model_dump() | |||
| process_data: dict[str, Any] = {} | |||
| @@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): | |||
| # remove the duplicated items first. | |||
| updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) | |||
| conv_var_updater = self._conv_var_updater_factory() | |||
| # Update variables | |||
| for selector in updated_variable_selectors: | |||
| variable = self.graph_runtime_state.variable_pool.get(selector) | |||
| @@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): | |||
| raise ConversationIDNotFoundError | |||
| else: | |||
| conversation_id = conversation_id.value | |||
| common_helpers.update_conversation_variable( | |||
| conv_var_updater.update( | |||
| conversation_id=cast(str, conversation_id), | |||
| variable=variable, | |||
| ) | |||
| conv_var_updater.flush() | |||
| updated_variables = [ | |||
| common_helpers.variable_to_processed_data(selector, seg) | |||
| for selector in updated_variable_selectors | |||
| if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None | |||
| ] | |||
| process_data = common_helpers.set_updated_variables(process_data, updated_variables) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=inputs, | |||
| process_data=process_data, | |||
| outputs={}, | |||
| ) | |||
| def _handle_item( | |||
| @@ -0,0 +1,29 @@ | |||
| from core.variables.segments import ObjectSegment, Segment | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| def append_variables_recursively( | |||
| pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment | |||
| ): | |||
| """ | |||
| Append variables recursively | |||
| :param pool: variable pool to append variables to | |||
| :param node_id: node id | |||
| :param variable_key_list: variable key list | |||
| :param variable_value: variable value | |||
| :return: | |||
| """ | |||
| pool.add([node_id] + variable_key_list, variable_value) | |||
| # if variable_value is a dict, then recursively append variables | |||
| if isinstance(variable_value, ObjectSegment): | |||
| variable_dict = variable_value.value | |||
| elif isinstance(variable_value, dict): | |||
| variable_dict = variable_value | |||
| else: | |||
| return | |||
| for key, value in variable_dict.items(): | |||
| # construct new key list | |||
| new_key_list = variable_key_list + [key] | |||
| append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) | |||
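In other words, adding an object value under `["node_1", "profile"]` also registers `["node_1", "profile", "name"]` and so on, so nested keys become addressable selectors. A self-contained sketch of the flattening with a toy stand-in for `VariablePool` (the real pool is not constructed here):

```python
class _ToyPool:
    """Stand-in for VariablePool: records every selector that gets added."""

    def __init__(self) -> None:
        self.entries: dict[tuple[str, ...], object] = {}

    def add(self, selector: list[str], value: object) -> None:
        self.entries[tuple(selector)] = value


def _append_recursively(pool: _ToyPool, node_id: str, keys: list[str], value: object) -> None:
    pool.add([node_id] + keys, value)
    if isinstance(value, dict):  # ObjectSegment unwraps to a dict in the real helper
        for key, child in value.items():
            _append_recursively(pool, node_id, keys + [key], child)


pool = _ToyPool()
_append_recursively(pool, "node_1", ["profile"], {"name": "Ada", "tags": ["x"]})
# ('node_1', 'profile'), ('node_1', 'profile', 'name'), ('node_1', 'profile', 'tags')
print(sorted(pool.entries))
```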
| @@ -0,0 +1,84 @@ | |||
| import abc | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Protocol | |||
| from core.variables import Variable | |||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.utils import variable_utils | |||
| class VariableLoader(Protocol): | |||
| """Interface for loading variables based on selectors. | |||
| A `VariableLoader` is responsible for retrieving additional variables required during the execution | |||
| of a single node, which are not provided as user inputs. | |||
| NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same | |||
| application and share the same `app_id`. However, this interface does not enforce that constraint, | |||
| and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of | |||
| concern and allow for flexible implementations. | |||
| Implementations of `VariableLoader` should almost always have an `app_id` parameter in | |||
| their constructor. | |||
| TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into | |||
| `WorkflowService.single_step_run`, we may get rid of this interface. | |||
| """ | |||
| @abc.abstractmethod | |||
| def load_variables(self, selectors: list[list[str]]) -> list[Variable]: | |||
| """Load variables based on the provided selectors. If the selectors are empty, | |||
| this method should return an empty list. | |||
| The order of the returned variables is not guaranteed. If the caller wants to ensure | |||
| a specific order, they should sort the returned list themselves. | |||
| :param selectors: a list of string lists; each inner list should have at least two elements: | |||
| - the first element is the node ID, | |||
| - the second element is the variable name. | |||
| :return: a list of Variable objects that match the provided selectors. | |||
| """ | |||
| pass | |||
| class _DummyVariableLoader(VariableLoader): | |||
| """A dummy implementation of VariableLoader that does not load any variables. | |||
| Serves as a placeholder when no variable loading is needed. | |||
| """ | |||
| def load_variables(self, selectors: list[list[str]]) -> list[Variable]: | |||
| return [] | |||
| DUMMY_VARIABLE_LOADER = _DummyVariableLoader() | |||
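Because `VariableLoader` is a `Protocol`, any object with a matching `load_variables` satisfies it; a test double can simply serve pre-built `Variable`s filtered by the first two selector elements. A hedged sketch (the class name is illustrative and not part of the codebase):

```python
class InMemoryVariableLoader:
    """Hypothetical test double: serves Variables from a preloaded list."""

    def __init__(self, variables: list[Variable]) -> None:
        self._variables = variables

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        if not selectors:
            return []
        wanted = {tuple(selector[:2]) for selector in selectors}
        return [v for v in self._variables if tuple(v.selector[:2]) in wanted]
```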
| def load_into_variable_pool( | |||
| variable_loader: VariableLoader, | |||
| variable_pool: VariablePool, | |||
| variable_mapping: Mapping[str, Sequence[str]], | |||
| user_inputs: Mapping[str, Any], | |||
| ): | |||
| # Load missing variables from draft variables here and set them into | |||
| # the variable_pool. | |||
| variables_to_load: list[list[str]] = [] | |||
| for key, selector in variable_mapping.items(): | |||
| # NOTE(QuantumGhost): this logic needs to be in sync with | |||
| # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. | |||
| node_variable_list = key.split(".") | |||
| if len(node_variable_list) < 1: | |||
| raise ValueError(f"Invalid variable key: {key}. It should have at least one element.") | |||
| if key in user_inputs: | |||
| continue | |||
| node_variable_key = ".".join(node_variable_list[1:]) | |||
| if node_variable_key in user_inputs: | |||
| continue | |||
| if variable_pool.get(selector) is None: | |||
| variables_to_load.append(list(selector)) | |||
| loaded = variable_loader.load_variables(variables_to_load) | |||
| for var in loaded: | |||
| assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" | |||
| variable_utils.append_variables_recursively( | |||
| variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var | |||
| ) | |||
| @@ -92,7 +92,7 @@ class WorkflowCycleManager: | |||
| ) -> WorkflowExecution: | |||
| workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) | |||
| outputs = WorkflowEntry.handle_special_values(outputs) | |||
| # outputs = WorkflowEntry.handle_special_values(outputs) | |||
| workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED | |||
| workflow_execution.outputs = outputs or {} | |||
| @@ -125,7 +125,7 @@ class WorkflowCycleManager: | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowExecution: | |||
| execution = self._get_workflow_execution_or_raise_error(workflow_run_id) | |||
| outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) | |||
| # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) | |||
| execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED | |||
| execution.outputs = outputs or {} | |||
| @@ -242,9 +242,9 @@ class WorkflowCycleManager: | |||
| raise ValueError(f"Domain node execution not found: {event.node_execution_id}") | |||
| # Process data | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| inputs = event.inputs | |||
| process_data = event.process_data | |||
| outputs = event.outputs | |||
| # Convert metadata keys to strings | |||
| execution_metadata_dict = {} | |||
| @@ -289,7 +289,7 @@ class WorkflowCycleManager: | |||
| # Process data | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| outputs = event.outputs | |||
| # Convert metadata keys to strings | |||
| execution_metadata_dict = {} | |||
| @@ -326,7 +326,7 @@ class WorkflowCycleManager: | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - created_at).total_seconds() | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| outputs = event.outputs | |||
| # Convert metadata keys to strings | |||
| origin_metadata = { | |||
| @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.event import NodeEvent | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from factories import file_factory | |||
| from models.enums import UserFrom | |||
| from models.workflow import ( | |||
| @@ -119,7 +120,9 @@ class WorkflowEntry: | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_id: str, | |||
| user_inputs: dict, | |||
| user_inputs: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: | |||
| """ | |||
| Single step run workflow node | |||
| @@ -129,29 +132,14 @@ class WorkflowEntry: | |||
| :param user_inputs: user inputs | |||
| :return: | |||
| """ | |||
| # fetch node info from workflow graph | |||
| workflow_graph = workflow.graph_dict | |||
| if not workflow_graph: | |||
| raise ValueError("workflow graph not found") | |||
| nodes = workflow_graph.get("nodes") | |||
| if not nodes: | |||
| raise ValueError("nodes not found in workflow graph") | |||
| # fetch node config from node id | |||
| try: | |||
| node_config = next(filter(lambda node: node["id"] == node_id, nodes)) | |||
| except StopIteration: | |||
| raise ValueError("node id not found in workflow graph") | |||
| node_config = workflow.get_node_config_by_id(node_id) | |||
| node_config_data = node_config.get("data", {}) | |||
| # Get node class | |||
| node_type = NodeType(node_config.get("data", {}).get("type")) | |||
| node_version = node_config.get("data", {}).get("version", "1") | |||
| node_type = NodeType(node_config_data.get("type")) | |||
| node_version = node_config_data.get("version", "1") | |||
| node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] | |||
| # init variable pool | |||
| variable_pool = VariablePool(environment_variables=workflow.environment_variables) | |||
| # init graph | |||
| graph = Graph.init(graph_config=workflow.graph_dict) | |||
| @@ -182,16 +170,33 @@ class WorkflowEntry: | |||
| except NotImplementedError: | |||
| variable_mapping = {} | |||
| # Load missing variables from draft variables here and set them into | |||
| # the variable_pool. | |||
| load_into_variable_pool( | |||
| variable_loader=variable_loader, | |||
| variable_pool=variable_pool, | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| ) | |||
| cls.mapping_user_inputs_to_variable_pool( | |||
| variable_mapping=variable_mapping, | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=workflow.tenant_id, | |||
| ) | |||
| try: | |||
| # run node | |||
| generator = node_instance.run() | |||
| except Exception as e: | |||
| logger.exception( | |||
| "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", | |||
| workflow.id, | |||
| node_instance.id, | |||
| node_instance.node_type, | |||
| node_instance.version(), | |||
| ) | |||
| raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) | |||
| return node_instance, generator | |||
| @@ -294,10 +299,20 @@ class WorkflowEntry: | |||
| return node_instance, generator | |||
| except Exception as e: | |||
| logger.exception( | |||
| "error while running node_instance, node_id=%s, type=%s, version=%s", | |||
| node_instance.id, | |||
| node_instance.node_type, | |||
| node_instance.version(), | |||
| ) | |||
| raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) | |||
| @staticmethod | |||
| def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: | |||
| # NOTE(QuantumGhost): Avoid using this function in new code. | |||
| # Keep values structured as long as possible and only convert to dict | |||
| # immediately before serialization (e.g., JSON serialization) to maintain | |||
| # data integrity and type information. | |||
| result = WorkflowEntry._handle_special_values(value) | |||
| return result if isinstance(result, Mapping) or result is None else dict(result) | |||
| @@ -324,10 +339,17 @@ class WorkflowEntry: | |||
| cls, | |||
| *, | |||
| variable_mapping: Mapping[str, Sequence[str]], | |||
| user_inputs: dict, | |||
| user_inputs: Mapping[str, Any], | |||
| variable_pool: VariablePool, | |||
| tenant_id: str, | |||
| ) -> None: | |||
| # NOTE(QuantumGhost): This logic should remain synchronized with | |||
| # the implementation of `load_into_variable_pool`, specifically the logic about | |||
| # variable existence checking. | |||
| # WARNING(QuantumGhost): The semantics of this method are not clearly defined, | |||
| # and multiple parts of the codebase depend on its current behavior. | |||
| # Modify with caution. | |||
| for node_variable, variable_selector in variable_mapping.items(): | |||
| # fetch node id and variable key from node_variable | |||
| node_variable_list = node_variable.split(".") | |||
| @@ -0,0 +1,49 @@ | |||
| import json | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from pydantic import BaseModel | |||
| from core.file.models import File | |||
| from core.variables import Segment | |||
| class WorkflowRuntimeTypeEncoder(json.JSONEncoder): | |||
| def default(self, o: Any): | |||
| if isinstance(o, Segment): | |||
| return o.value | |||
| elif isinstance(o, File): | |||
| return o.to_dict() | |||
| elif isinstance(o, BaseModel): | |||
| return o.model_dump(mode="json") | |||
| else: | |||
| return super().default(o) | |||
| class WorkflowRuntimeTypeConverter: | |||
| def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: | |||
| result = self._to_json_encodable_recursive(value) | |||
| return result if isinstance(result, Mapping) or result is None else dict(result) | |||
| def _to_json_encodable_recursive(self, value: Any) -> Any: | |||
| if value is None: | |||
| return value | |||
| if isinstance(value, (bool, int, str, float)): | |||
| return value | |||
| if isinstance(value, Segment): | |||
| return self._to_json_encodable_recursive(value.value) | |||
| if isinstance(value, File): | |||
| return value.to_dict() | |||
| if isinstance(value, BaseModel): | |||
| return value.model_dump(mode="json") | |||
| if isinstance(value, dict): | |||
| res = {} | |||
| for k, v in value.items(): | |||
| res[k] = self._to_json_encodable_recursive(v) | |||
| return res | |||
| if isinstance(value, list): | |||
| res_list = [] | |||
| for item in value: | |||
| res_list.append(self._to_json_encodable_recursive(item)) | |||
| return res_list | |||
| return value | |||
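Both helpers unwrap `Segment`s to their raw values, serialize `File`s via `to_dict()`, and dump other pydantic models; the converter produces a plain mapping up front, while the encoder does the same lazily inside `json.dumps`. A hedged usage sketch, reusing `build_segment` from the factory module changed above:

```python
import json

from factories.variable_factory import build_segment

outputs = {"greeting": build_segment("hello"), "count": build_segment(2)}

# Option 1: convert first, then serialize with the stock encoder.
plain = WorkflowRuntimeTypeConverter().to_json_encodable(outputs)  # {"greeting": "hello", "count": 2}

# Option 2: serialize directly with the custom encoder.
as_json = json.dumps(outputs, cls=WorkflowRuntimeTypeEncoder)
assert json.loads(as_json) == plain
```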
| @@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle | |||
| from .create_document_index import handle | |||
| from .create_installed_app_when_app_created import handle | |||
| from .create_site_record_when_app_created import handle | |||
| from .deduct_quota_when_message_created import handle | |||
| from .delete_tool_parameters_cache_when_sync_draft_workflow import handle | |||
| from .update_app_dataset_join_when_app_model_config_updated import handle | |||
| from .update_app_dataset_join_when_app_published_workflow_updated import handle | |||
| from .update_provider_last_used_at_when_message_created import handle | |||
| # Consolidated handler replaces both deduct_quota_when_message_created and | |||
| # update_provider_last_used_at_when_message_created | |||
| from .update_provider_when_message_created import handle | |||
| @@ -1,65 +0,0 @@ | |||
| from datetime import UTC, datetime | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.provider import Provider, ProviderType | |||
| @message_was_created.connect | |||
| def handle(sender, **kwargs): | |||
| message = sender | |||
| application_generate_entity = kwargs.get("application_generate_entity") | |||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||
| return | |||
| model_config = application_generate_entity.model_conf | |||
| provider_model_bundle = model_config.provider_model_bundle | |||
| provider_configuration = provider_model_bundle.configuration | |||
| if provider_configuration.using_provider_type != ProviderType.SYSTEM: | |||
| return | |||
| system_configuration = provider_configuration.system_configuration | |||
| if not system_configuration.current_quota_type: | |||
| return | |||
| quota_unit = None | |||
| for quota_configuration in system_configuration.quota_configurations: | |||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||
| quota_unit = quota_configuration.quota_unit | |||
| if quota_configuration.quota_limit == -1: | |||
| return | |||
| break | |||
| used_quota = None | |||
| if quota_unit: | |||
| if quota_unit == QuotaUnit.TOKENS: | |||
| used_quota = message.message_tokens + message.answer_tokens | |||
| elif quota_unit == QuotaUnit.CREDITS: | |||
| used_quota = dify_config.get_model_credits(model_config.model) | |||
| else: | |||
| used_quota = 1 | |||
| if used_quota is not None and system_configuration.current_quota_type is not None: | |||
| db.session.query(Provider).filter( | |||
| Provider.tenant_id == application_generate_entity.app_config.tenant_id, | |||
| # TODO: Use provider name with prefix after the data migration. | |||
| Provider.provider_name == ModelProviderID(model_config.provider).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||
| Provider.quota_limit > Provider.quota_used, | |||
| ).update( | |||
| { | |||
| "quota_used": Provider.quota_used + used_quota, | |||
| "last_used": datetime.now(tz=UTC).replace(tzinfo=None), | |||
| } | |||
| ) | |||
| db.session.commit() | |||
| @@ -1,20 +0,0 @@ | |||
| from datetime import UTC, datetime | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.provider import Provider | |||
| @message_was_created.connect | |||
| def handle(sender, **kwargs): | |||
| application_generate_entity = kwargs.get("application_generate_entity") | |||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||
| return | |||
| db.session.query(Provider).filter( | |||
| Provider.tenant_id == application_generate_entity.app_config.tenant_id, | |||
| Provider.provider_name == application_generate_entity.model_conf.provider, | |||
| ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) | |||
| db.session.commit() | |||
| @@ -0,0 +1,234 @@ | |||
| import logging | |||
| import time as time_module | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import update | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||
| from core.entities.provider_entities import QuotaUnit, SystemConfiguration | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from libs import datetime_utils | |||
| from models.model import Message | |||
| from models.provider import Provider, ProviderType | |||
| logger = logging.getLogger(__name__) | |||
| class _ProviderUpdateFilters(BaseModel): | |||
| """Filters for identifying Provider records to update.""" | |||
| tenant_id: str | |||
| provider_name: str | |||
| provider_type: Optional[str] = None | |||
| quota_type: Optional[str] = None | |||
| class _ProviderUpdateAdditionalFilters(BaseModel): | |||
| """Additional filters for Provider updates.""" | |||
| quota_limit_check: bool = False | |||
| class _ProviderUpdateValues(BaseModel): | |||
| """Values to update in Provider records.""" | |||
| last_used: Optional[datetime] = None | |||
| quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression | |||
| class _ProviderUpdateOperation(BaseModel): | |||
| """A single Provider update operation.""" | |||
| filters: _ProviderUpdateFilters | |||
| values: _ProviderUpdateValues | |||
| additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters() | |||
| description: str = "unknown" | |||
| @message_was_created.connect | |||
| def handle(sender: Message, **kwargs): | |||
| """ | |||
| Consolidated handler for Provider updates when a message is created. | |||
| This handler replaces both: | |||
| - update_provider_last_used_at_when_message_created | |||
| - deduct_quota_when_message_created | |||
| By performing all Provider updates in a single transaction, we ensure | |||
| consistency and efficiency when updating Provider records. | |||
| """ | |||
| message = sender | |||
| application_generate_entity = kwargs.get("application_generate_entity") | |||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||
| return | |||
| tenant_id = application_generate_entity.app_config.tenant_id | |||
| provider_name = application_generate_entity.model_conf.provider | |||
| current_time = datetime_utils.naive_utc_now() | |||
| # Prepare updates for both scenarios | |||
| updates_to_perform: list[_ProviderUpdateOperation] = [] | |||
| # 1. Always update last_used for the provider | |||
| basic_update = _ProviderUpdateOperation( | |||
| filters=_ProviderUpdateFilters( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_name, | |||
| ), | |||
| values=_ProviderUpdateValues(last_used=current_time), | |||
| description="basic_last_used_update", | |||
| ) | |||
| updates_to_perform.append(basic_update) | |||
| # 2. Check if we need to deduct quota (system provider only) | |||
| model_config = application_generate_entity.model_conf | |||
| provider_model_bundle = model_config.provider_model_bundle | |||
| provider_configuration = provider_model_bundle.configuration | |||
| if ( | |||
| provider_configuration.using_provider_type == ProviderType.SYSTEM | |||
| and provider_configuration.system_configuration | |||
| and provider_configuration.system_configuration.current_quota_type is not None | |||
| ): | |||
| system_configuration = provider_configuration.system_configuration | |||
| # Calculate quota usage | |||
| used_quota = _calculate_quota_usage( | |||
| message=message, | |||
| system_configuration=system_configuration, | |||
| model_name=model_config.model, | |||
| ) | |||
| if used_quota is not None: | |||
| quota_update = _ProviderUpdateOperation( | |||
| filters=_ProviderUpdateFilters( | |||
| tenant_id=tenant_id, | |||
| provider_name=ModelProviderID(model_config.provider).provider_name, | |||
| provider_type=ProviderType.SYSTEM.value, | |||
| quota_type=provider_configuration.system_configuration.current_quota_type.value, | |||
| ), | |||
| values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), | |||
| additional_filters=_ProviderUpdateAdditionalFilters( | |||
| quota_limit_check=True # Provider.quota_limit > Provider.quota_used | |||
| ), | |||
| description="quota_deduction_update", | |||
| ) | |||
| updates_to_perform.append(quota_update) | |||
| # Execute all updates | |||
| start_time = time_module.perf_counter() | |||
| try: | |||
| _execute_provider_updates(updates_to_perform) | |||
| # Log successful completion with timing | |||
| duration = time_module.perf_counter() - start_time | |||
| logger.info( | |||
| f"Provider updates completed successfully. " | |||
| f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " | |||
| f"Tenant: {tenant_id}, Provider: {provider_name}" | |||
| ) | |||
| except Exception as e: | |||
| # Log failure with timing and context | |||
| duration = time_module.perf_counter() - start_time | |||
| logger.exception( | |||
| f"Provider updates failed after {duration:.3f}s. " | |||
| f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " | |||
| f"Provider: {provider_name}" | |||
| ) | |||
| raise | |||
| def _calculate_quota_usage( | |||
| *, message: Message, system_configuration: SystemConfiguration, model_name: str | |||
| ) -> Optional[int]: | |||
| """Calculate quota usage based on message tokens and quota type.""" | |||
| quota_unit = None | |||
| for quota_configuration in system_configuration.quota_configurations: | |||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||
| quota_unit = quota_configuration.quota_unit | |||
| if quota_configuration.quota_limit == -1: | |||
| return None | |||
| break | |||
| if quota_unit is None: | |||
| return None | |||
| try: | |||
| if quota_unit == QuotaUnit.TOKENS: | |||
| tokens = message.message_tokens + message.answer_tokens | |||
| return tokens | |||
| if quota_unit == QuotaUnit.CREDITS: | |||
| tokens = dify_config.get_model_credits(model_name) | |||
| return tokens | |||
| elif quota_unit == QuotaUnit.TIMES: | |||
| return 1 | |||
| return None | |||
| except Exception as e: | |||
| logger.exception("Failed to calculate quota usage") | |||
| return None | |||
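A short worked example of the three quota units (the token counts are made up): with 120 prompt tokens and 80 completion tokens, `TOKENS` deducts the full 200, `TIMES` always deducts one call, and `CREDITS` deducts whatever `dify_config.get_model_credits(model_name)` returns for that model, while an unlimited quota (`quota_limit == -1`) short-circuits to `None` so nothing is deducted.

```python
# Hypothetical message: 120 prompt tokens, 80 completion tokens.
message_tokens, answer_tokens = 120, 80

# QuotaUnit.TOKENS -> deduct the full token count.
assert message_tokens + answer_tokens == 200

# QuotaUnit.TIMES -> deduct exactly one call, regardless of size.
used_for_times = 1

# QuotaUnit.CREDITS -> deduct dify_config.get_model_credits(model_name);
# an unlimited quota (quota_limit == -1) returns None and skips the deduction.
```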
| def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]): | |||
| """Execute all Provider updates in a single transaction.""" | |||
| if not updates_to_perform: | |||
| return | |||
| # Use SQLAlchemy's context manager for transaction management | |||
| # This automatically handles commit/rollback | |||
| with Session(db.engine) as session: | |||
| # Use a single transaction for all updates | |||
| for update_operation in updates_to_perform: | |||
| filters = update_operation.filters | |||
| values = update_operation.values | |||
| additional_filters = update_operation.additional_filters | |||
| description = update_operation.description | |||
| # Build the where conditions | |||
| where_conditions = [ | |||
| Provider.tenant_id == filters.tenant_id, | |||
| Provider.provider_name == filters.provider_name, | |||
| ] | |||
| # Add additional filters if specified | |||
| if filters.provider_type is not None: | |||
| where_conditions.append(Provider.provider_type == filters.provider_type) | |||
| if filters.quota_type is not None: | |||
| where_conditions.append(Provider.quota_type == filters.quota_type) | |||
| if additional_filters.quota_limit_check: | |||
| where_conditions.append(Provider.quota_limit > Provider.quota_used) | |||
| # Prepare values dict for SQLAlchemy update | |||
| update_values = {} | |||
| if values.last_used is not None: | |||
| update_values["last_used"] = values.last_used | |||
| if values.quota_used is not None: | |||
| update_values["quota_used"] = values.quota_used | |||
| # Build and execute the update statement | |||
| stmt = update(Provider).where(*where_conditions).values(**update_values) | |||
| result = session.execute(stmt) | |||
| rows_affected = result.rowcount | |||
| logger.debug( | |||
| f"Provider update ({description}): {rows_affected} rows affected. " | |||
| f"Filters: {filters.model_dump()}, Values: {update_values}" | |||
| ) | |||
| # If no rows were affected for quota updates, log a warning | |||
| if rows_affected == 0 and description == "quota_deduction_update": | |||
| logger.warning( | |||
| f"No Provider rows updated for quota deduction. " | |||
| f"This may indicate quota limit exceeded or provider not found. " | |||
| f"Filters: {filters.model_dump()}" | |||
| ) | |||
| logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") | |||
| @@ -5,6 +5,7 @@ from typing import Any, cast | |||
| import httpx | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS | |||
| from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers | |||
| @@ -91,6 +92,8 @@ def build_from_mappings( | |||
| tenant_id: str, | |||
| strict_type_validation: bool = False, | |||
| ) -> Sequence[File]: | |||
| # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. | |||
| # Implement batch processing to reduce database load when handling multiple files. | |||
| files = [ | |||
| build_from_mapping( | |||
| mapping=mapping, | |||
| @@ -377,3 +380,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: | |||
| def get_file_type_by_mime_type(mime_type: str) -> FileType: | |||
| return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM | |||
| class StorageKeyLoader: | |||
| """FileKeyLoader load the storage key from database for a list of files. | |||
| This loader is batched, the database query count is constant regardless of the input size. | |||
| """ | |||
| def __init__(self, session: Session, tenant_id: str) -> None: | |||
| self._session = session | |||
| self._tenant_id = tenant_id | |||
| def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: | |||
| stmt = select(UploadFile).where( | |||
| UploadFile.id.in_(upload_file_ids), | |||
| UploadFile.tenant_id == self._tenant_id, | |||
| ) | |||
| return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} | |||
| def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: | |||
| stmt = select(ToolFile).where( | |||
| ToolFile.id.in_(tool_file_ids), | |||
| ToolFile.tenant_id == self._tenant_id, | |||
| ) | |||
| return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} | |||
| def load_storage_keys(self, files: Sequence[File]): | |||
| """Loads storage keys for a sequence of files by retrieving the corresponding | |||
| `UploadFile` or `ToolFile` records from the database based on their transfer method. | |||
| This method doesn't modify the input sequence structure but updates the `_storage_key` | |||
| property of each file object by extracting the relevant key from its database record. | |||
| Performance note: This is a batched operation where database query count remains constant | |||
| regardless of input size. However, for optimal performance, input sequences should contain | |||
| fewer than 1000 files. For larger collections, split into smaller batches and process each | |||
| batch separately. | |||
| """ | |||
| upload_file_ids: list[uuid.UUID] = [] | |||
| tool_file_ids: list[uuid.UUID] = [] | |||
| for file in files: | |||
| related_model_id = file.related_id | |||
| if file.related_id is None: | |||
| raise ValueError("file id should not be None.") | |||
| if file.tenant_id != self._tenant_id: | |||
| err_msg = ( | |||
| f"invalid file, expected tenant_id={self._tenant_id}, " | |||
| f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" | |||
| ) | |||
| raise ValueError(err_msg) | |||
| model_id = uuid.UUID(related_model_id) | |||
| if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): | |||
| upload_file_ids.append(model_id) | |||
| elif file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| tool_file_ids.append(model_id) | |||
| tool_files = self._load_tool_files(tool_file_ids) | |||
| upload_files = self._load_upload_files(upload_file_ids) | |||
| for file in files: | |||
| model_id = uuid.UUID(file.related_id) | |||
| if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): | |||
| upload_file_row = upload_files.get(model_id) | |||
| if upload_file_row is None: | |||
| raise ValueError(f"Upload file not found for id: {model_id}") | |||
| file._storage_key = upload_file_row.key | |||
| elif file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| tool_file_row = tool_files.get(model_id) | |||
| if tool_file_row is None: | |||
| raise ValueError(f"Tool file not found for id: {model_id}") | |||
| file._storage_key = tool_file_row.file_key | |||
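Typical usage opens a single `Session`, builds the loader for the current tenant, and lets it backfill `_storage_key` on every file with two queries in total. A hedged sketch (the helper name, tenant id, and file list are placeholders):

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db


def backfill_storage_keys(files, tenant_id: str) -> None:
    """Hypothetical helper: batch-load storage keys before reading file contents."""
    with Session(db.engine) as session:
        loader = StorageKeyLoader(session, tenant_id=tenant_id)
        loader.load_storage_keys(files)
```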
| @@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception): | |||
| pass | |||
| class TypeMismatchError(Exception): | |||
| pass | |||
| # Define the constant | |||
| SEGMENT_TO_VARIABLE_MAP = { | |||
| StringSegment: StringVariable, | |||
| @@ -110,6 +114,10 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen | |||
| return cast(Variable, result) | |||
| def infer_segment_type_from_value(value: Any, /) -> SegmentType: | |||
| return build_segment(value).value_type | |||
| def build_segment(value: Any, /) -> Segment: | |||
| if value is None: | |||
| return NoneSegment() | |||
| @@ -140,10 +148,80 @@ def build_segment(value: Any, /) -> Segment: | |||
| case SegmentType.NONE: | |||
| return ArrayAnySegment(value=value) | |||
| case _: | |||
| # This should be unreachable. | |||
| raise ValueError(f"not supported value {value}") | |||
| raise ValueError(f"not supported value {value}") | |||
| def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: | |||
| """ | |||
| Build a segment with explicit type checking. | |||
| This function creates a segment from a value while enforcing type compatibility | |||
| with the specified segment_type. It provides stricter type validation compared | |||
| to the standard build_segment function. | |||
| Args: | |||
| segment_type: The expected SegmentType for the resulting segment | |||
| value: The value to be converted into a segment | |||
| Returns: | |||
| Segment: A segment instance of the appropriate type | |||
| Raises: | |||
| TypeMismatchError: If the value type doesn't match the expected segment_type | |||
| Special Cases: | |||
| - For empty list [] values, if segment_type is array[*], returns the corresponding array type | |||
| - Type validation is performed before segment creation | |||
| Examples: | |||
| >>> build_segment_with_type(SegmentType.STRING, "hello") | |||
| StringSegment(value="hello") | |||
| >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) | |||
| ArrayStringSegment(value=[]) | |||
| >>> build_segment_with_type(SegmentType.STRING, 123) | |||
| # Raises TypeMismatchError | |||
| """ | |||
| # Handle None values | |||
| if value is None: | |||
| if segment_type == SegmentType.NONE: | |||
| return NoneSegment() | |||
| else: | |||
| raise TypeMismatchError(f"Expected {segment_type}, but got None") | |||
| # Handle empty list special case for array types | |||
| if isinstance(value, list) and len(value) == 0: | |||
| if segment_type == SegmentType.ARRAY_ANY: | |||
| return ArrayAnySegment(value=value) | |||
| elif segment_type == SegmentType.ARRAY_STRING: | |||
| return ArrayStringSegment(value=value) | |||
| elif segment_type == SegmentType.ARRAY_NUMBER: | |||
| return ArrayNumberSegment(value=value) | |||
| elif segment_type == SegmentType.ARRAY_OBJECT: | |||
| return ArrayObjectSegment(value=value) | |||
| elif segment_type == SegmentType.ARRAY_FILE: | |||
| return ArrayFileSegment(value=value) | |||
| else: | |||
| raise TypeMismatchError(f"Expected {segment_type}, but got empty list") | |||
| # Build segment using existing logic to infer actual type | |||
| inferred_segment = build_segment(value) | |||
| inferred_type = inferred_segment.value_type | |||
| # Type compatibility checking | |||
| if inferred_type == segment_type: | |||
| return inferred_segment | |||
| # Type mismatch - raise error with descriptive message | |||
| raise TypeMismatchError( | |||
| f"Type mismatch: expected {segment_type}, but value '{value}' " | |||
| f"(type: {type(value).__name__}) corresponds to {inferred_type}" | |||
| ) | |||
| def segment_to_variable( | |||
| *, | |||
| segment: Segment, | |||
| @@ -0,0 +1,22 @@ | |||
| import abc | |||
| import datetime | |||
| from typing import Protocol | |||
| class _NowFunction(Protocol): | |||
| @abc.abstractmethod | |||
| def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: | |||
| pass | |||
| # _now_func is a callable with the _NowFunction signature. | |||
| # Its sole purpose is to abstract time retrieval, enabling | |||
| # developers to mock this behavior in tests and time-dependent scenarios. | |||
| _now_func: _NowFunction = datetime.datetime.now | |||
| def naive_utc_now() -> datetime.datetime: | |||
| """Return a naive datetime object (without timezone information) | |||
| representing current UTC time. | |||
| """ | |||
| return _now_func(datetime.UTC).replace(tzinfo=None) | |||
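Because the module reads `_now_func` at call time, tests can pin the clock by swapping that seam instead of patching `datetime` globally. A hedged sketch using pytest's `monkeypatch` fixture (assuming the module lives at `libs.datetime_utils`, as imported elsewhere in this change):

```python
import datetime

from libs import datetime_utils


def test_naive_utc_now_is_deterministic(monkeypatch):
    fixed = datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)
    monkeypatch.setattr(datetime_utils, "_now_func", lambda tz=None: fixed)
    assert datetime_utils.naive_utc_now() == fixed.replace(tzinfo=None)
```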
| @@ -0,0 +1,11 @@ | |||
| import json | |||
| from pydantic import BaseModel | |||
| class PydanticModelEncoder(json.JSONEncoder): | |||
| def default(self, o): | |||
| if isinstance(o, BaseModel): | |||
| return o.model_dump() | |||
| else: | |||
| return super().default(o) | |||
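The encoder lets `json.dumps` serialize pydantic models transparently, while anything it does not recognize falls through to the base class (which raises `TypeError`). A small self-contained usage sketch:

```python
import json

from pydantic import BaseModel


class Greeting(BaseModel):
    text: str
    times: int


payload = {"greeting": Greeting(text="hi", times=2), "plain": [1, 2, 3]}
print(json.dumps(payload, cls=PydanticModelEncoder))
# {"greeting": {"text": "hi", "times": 2}, "plain": [1, 2, 3]}
```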
| @@ -22,7 +22,11 @@ class SMTPClient: | |||
| if self.use_tls: | |||
| if self.opportunistic_tls: | |||
| smtp = smtplib.SMTP(self.server, self.port, timeout=10) | |||
| # Send EHLO command with the HELO domain name as the server address | |||
| smtp.ehlo(self.server) | |||
| smtp.starttls() | |||
| # Resend EHLO command to identify the TLS session | |||
| smtp.ehlo(self.server) | |||
| else: | |||
| smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) | |||
| else: | |||
| @@ -0,0 +1,20 @@ | |||
| """All these exceptions are not meant to be caught by callers.""" | |||
| class WorkflowDataError(Exception): | |||
| """Base class for all workflow data related exceptions. | |||
| This should be used to indicate issues with workflow data integrity, such as | |||
| no `graph` configuration, missing `nodes` field in `graph` configuration, or | |||
| similar issues. | |||
| """ | |||
| pass | |||
| class NodeNotFoundError(WorkflowDataError): | |||
| """Raised when a node with the specified ID is not found in the workflow.""" | |||
| def __init__(self, node_id: str): | |||
| super().__init__(f"Node with ID '{node_id}' not found in the workflow.") | |||
| self.node_id = node_id | |||
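These exceptions exist to be raised, not caught, when workflow data is structurally broken. A hedged sketch of the kind of accessor they back, mirroring the node lookup that `WorkflowEntry.single_step_run` now delegates to `workflow.get_node_config_by_id` (the function below is illustrative, not the actual model method):

```python
def get_node_config_by_id(graph_dict: dict, node_id: str) -> dict:
    """Illustrative lookup: raise typed errors instead of returning None on bad data."""
    nodes = (graph_dict or {}).get("nodes")
    if not nodes:
        raise WorkflowDataError("nodes not found in workflow graph")
    for node in nodes:
        if node.get("id") == node_id:
            return node
    raise NodeNotFoundError(node_id)
```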