| @@ -8,13 +8,15 @@ body: | |||
| label: Self Checks | |||
| description: "To make sure we get to you in time, please check the following :)" | |||
| options: | |||
| - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). | |||
| required: true | |||
| - label: This is only for bug reports. If you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). | |||
| required: true | |||
| - label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones. | |||
| required: true | |||
| - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). | |||
| - label: I confirm that I am using English to submit this report; otherwise, it will be closed. | |||
| required: true | |||
| - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" | |||
| - label: 【中文用户 & Non-English Users】请使用英语提交,否则会被关闭 :) | |||
| required: true | |||
| - label: "Please do not modify this template :) and fill in all the required fields." | |||
| required: true | |||
| @@ -42,20 +44,22 @@ body: | |||
| attributes: | |||
| label: Steps to reproduce | |||
| description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. | |||
| placeholder: Having detailed steps helps us reproduce the bug. | |||
| placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them. | |||
| validations: | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: ✔️ Expected Behavior | |||
| placeholder: What were you expecting? | |||
| description: Describe what you expected to happen. | |||
| placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here. | |||
| validations: | |||
| required: false | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: ❌ Actual Behavior | |||
| placeholder: What happened instead? | |||
| description: Describe what actually happened. | |||
| placeholder: What happened instead? Please do not copy and paste the steps to reproduce here. | |||
| validations: | |||
| required: false | |||
| @@ -1,5 +1,11 @@ | |||
| blank_issues_enabled: false | |||
| contact_links: | |||
| - name: "\U0001F4A1 Model Providers & Plugins" | |||
| url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" | |||
| about: Report issues with official plugins or model providers. You will need to provide the plugin version and other relevant details. | |||
| - name: "\U0001F4AC Documentation Issues" | |||
| url: "https://github.com/langgenius/dify-docs/issues/new" | |||
| about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue. | |||
| - name: "\U0001F4E7 Discussions" | |||
| url: https://github.com/langgenius/dify/discussions/categories/general | |||
| about: General discussions and request help from the community | |||
| about: General discussions and seek help from the community | |||
| @@ -1,24 +0,0 @@ | |||
| name: "📚 Documentation Issue" | |||
| description: Report issues in our documentation | |||
| labels: | |||
| - documentation | |||
| body: | |||
| - type: checkboxes | |||
| attributes: | |||
| label: Self Checks | |||
| description: "To make sure we get to you in time, please check the following :)" | |||
| options: | |||
| - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. | |||
| required: true | |||
| - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). | |||
| required: true | |||
| - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" | |||
| required: true | |||
| - label: "Please do not modify this template :) and fill in all the required fields." | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: Provide a description of requested docs changes | |||
| placeholder: Briefly describe which document needs to be corrected and why. | |||
| validations: | |||
| required: true | |||
| @@ -8,11 +8,11 @@ body: | |||
| label: Self Checks | |||
| description: "To make sure we get to you in time, please check the following :)" | |||
| options: | |||
| - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. | |||
| - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). | |||
| required: true | |||
| - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). | |||
| - label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones. | |||
| required: true | |||
| - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" | |||
| - label: I confirm that I am using English to submit this report; otherwise, it will be closed. | |||
| required: true | |||
| - label: "Please do not modify this template :) and fill in all the required fields." | |||
| required: true | |||
| @@ -1,55 +0,0 @@ | |||
| name: "🌐 Localization/Translation issue" | |||
| description: Report incorrect translations. [please use English :)] | |||
| labels: | |||
| - translation | |||
| body: | |||
| - type: checkboxes | |||
| attributes: | |||
| label: Self Checks | |||
| description: "To make sure we get to you in time, please check the following :)" | |||
| options: | |||
| - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. | |||
| required: true | |||
| - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). | |||
| required: true | |||
| - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" | |||
| required: true | |||
| - label: "Please do not modify this template :) and fill in all the required fields." | |||
| required: true | |||
| - type: input | |||
| attributes: | |||
| label: Dify version | |||
| description: Hover over system tray icon or look at Settings | |||
| validations: | |||
| required: true | |||
| - type: input | |||
| attributes: | |||
| label: Utility with translation issue | |||
| placeholder: Some area | |||
| description: Please input here the utility with the translation issue | |||
| validations: | |||
| required: true | |||
| - type: input | |||
| attributes: | |||
| label: 🌐 Language affected | |||
| placeholder: "German" | |||
| validations: | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: ❌ Actual phrase(s) | |||
| placeholder: What is there? Please include a screenshot as that is extremely helpful. | |||
| validations: | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: ✔️ Expected phrase(s) | |||
| placeholder: What was expected? | |||
| validations: | |||
| required: true | |||
| - type: textarea | |||
| attributes: | |||
| label: ℹ Why is the current translation wrong | |||
| placeholder: Why do you feel this is incorrect? | |||
| validations: | |||
| required: true | |||
| @@ -65,7 +65,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i | |||
| </br> | |||
| The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: | |||
| The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: | |||
| ```bash | |||
| cd dify | |||
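# Typical remaining steps, assuming the repository's standard docker/ layout:
cd docker
cp .env.example .env
docker compose up -d
```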
| @@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Using Terraform for Deployment | |||
| @@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media | |||
| ## Security disclosure | |||
| To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. | |||
| To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with a detailed answer. | |||
| ## License | |||
| This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. | |||
| This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions. | |||
| @@ -188,6 +188,7 @@ docker compose up -d | |||
| - [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### استخدام Terraform للتوزيع | |||
| @@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### টেরাফর্ম ব্যবহার করে ডিপ্লয় | |||
| @@ -194,9 +194,9 @@ docker compose up -d | |||
| 如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 | |||
| #### 使用 Helm Chart 部署 | |||
| #### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 | |||
| 使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。 | |||
| 使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 | |||
| - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) | |||
| - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) | |||
| @@ -204,6 +204,10 @@ docker compose up -d | |||
| - [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### 使用 Terraform 部署 | |||
| 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 | |||
| @@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Terraform für die Bereitstellung verwenden | |||
| @@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop | |||
| - [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Uso de Terraform para el despliegue | |||
| @@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau | |||
| - [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NOUVEAU ! Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Utilisation de Terraform pour le déploiement | |||
| @@ -202,6 +202,7 @@ docker compose up -d | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Terraformを使用したデプロイ | |||
| @@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Terraform atorlugu pilersitsineq | |||
| @@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 | |||
| - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Terraform을 사용한 배포 | |||
| @@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] | |||
| - [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts) | |||
| - [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Usando o Terraform para Implantação | |||
| @@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases. | |||
| - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) | |||
| - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Uporaba Terraform za uvajanje | |||
| @@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' | |||
| - [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm) | |||
| - [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes) | |||
| - [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 YENİ! YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Dağıtım için Terraform Kullanımı | |||
| @@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify | |||
| 如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 | |||
| 如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。 | |||
| 如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 | |||
| - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) | |||
| - [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm) | |||
| - [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes) | |||
| - [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| ### 使用 Terraform 進行部署 | |||
| @@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có | |||
| - [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) | |||
| - [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) | |||
| - [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s) | |||
| - [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) | |||
| #### Sử dụng Terraform để Triển khai | |||
| @@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800 | |||
| # hybrid: Save new data to object storage, read from both object storage and RDBMS | |||
| WORKFLOW_NODE_EXECUTION_STORAGE=rdbms | |||
| # Repository configuration | |||
| # Core workflow execution repository implementation | |||
| CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository | |||
| # Core workflow node execution repository implementation | |||
| CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository | |||
| # API workflow node execution repository implementation | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository | |||
| # API workflow run repository implementation | |||
| API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository | |||
| # App configuration | |||
| APP_MAX_EXECUTION_TIME=1200 | |||
| APP_MAX_ACTIVE_REQUESTS=0 | |||
| @@ -482,6 +495,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} | |||
| # Reset password token expiry minutes | |||
| RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 | |||
| CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 | |||
| OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 | |||
| CREATE_TIDB_SERVICE_JOB_ENABLED=false | |||
| @@ -492,6 +507,8 @@ LOGIN_LOCKOUT_DURATION=86400 | |||
| # Enable OpenTelemetry | |||
| ENABLE_OTEL=false | |||
| OTLP_TRACE_ENDPOINT= | |||
| OTLP_METRIC_ENDPOINT= | |||
| OTLP_BASE_ENDPOINT=http://localhost:4318 | |||
| OTLP_API_KEY= | |||
| OTEL_EXPORTER_OTLP_PROTOCOL= | |||
| @@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings): | |||
| description="Duration in minutes for which a password reset token remains valid", | |||
| default=5, | |||
| ) | |||
| CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( | |||
| description="Duration in minutes for which a change email token remains valid", | |||
| default=5, | |||
| ) | |||
| OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( | |||
| description="Duration in minutes for which a owner transfer token remains valid", | |||
| default=5, | |||
| ) | |||
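As a rough illustration of how a token governed by one of these expiry settings can be enforced, the sketch below stashes a change-email payload in Redis with a TTL derived from the configured minutes. The `redis_client` wiring, key naming, and payload shape are assumptions made for the example rather than Dify's actual token storage.

```python
import json
import secrets

import redis

redis_client = redis.Redis()  # assumed connection; the real app wires its own Redis extension


def issue_change_email_token(email: str, code: str, expiry_minutes: int = 5) -> str:
    """Create an opaque token whose payload expires after `expiry_minutes`."""
    token = secrets.token_urlsafe(32)
    payload = json.dumps({"email": email, "code": code})
    # setex stores the value with a TTL in seconds, so the token silently
    # becomes invalid once CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES have elapsed.
    redis_client.setex(f"change_email_token:{token}", expiry_minutes * 60, payload)
    return token
```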
| LOGIN_DISABLED: bool = Field( | |||
| description="Whether to disable login checks", | |||
| @@ -537,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings): | |||
| ) | |||
| class RepositoryConfig(BaseSettings): | |||
| """ | |||
| Configuration for repository implementations | |||
| """ | |||
| CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( | |||
| description="Repository implementation for WorkflowExecution. Specify as a module path", | |||
| default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", | |||
| ) | |||
| CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( | |||
| description="Repository implementation for WorkflowNodeExecution. Specify as a module path", | |||
| default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", | |||
| ) | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( | |||
| description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " | |||
| "Specify as a module path", | |||
| default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", | |||
| ) | |||
| API_WORKFLOW_RUN_REPOSITORY: str = Field( | |||
| description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path", | |||
| default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", | |||
| ) | |||
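Each of these settings is a dotted module path, so the factory that consumes them only needs a small dynamic-import step. The helper below is a minimal sketch of that idea; its name is hypothetical, and Dify's actual factories may resolve and validate the configured class differently.

```python
import importlib


def import_class_from_path(dotted_path: str) -> type:
    """Split 'package.module.ClassName' and return the class object."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# For example, resolving the configured workflow execution repository
# (constructor arguments mirror what the factory calls pass elsewhere in this diff):
# repo_cls = import_class_from_path(dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY)
# repository = repo_cls(session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from)
```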
| class AuthConfig(BaseSettings): | |||
| """ | |||
| Configuration for authentication and OAuth | |||
| @@ -587,6 +623,16 @@ class AuthConfig(BaseSettings): | |||
| default=86400, | |||
| ) | |||
| CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field( | |||
| description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.", | |||
| default=86400, | |||
| ) | |||
| OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field( | |||
| description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.", | |||
| default=86400, | |||
| ) | |||
| class ModerationConfig(BaseSettings): | |||
| """ | |||
| @@ -903,6 +949,7 @@ class FeatureConfig( | |||
| MultiModalTransferConfig, | |||
| PositionConfig, | |||
| RagEtlConfig, | |||
| RepositoryConfig, | |||
| SecurityConfig, | |||
| ToolConfig, | |||
| UpdateConfig, | |||
| @@ -162,6 +162,11 @@ class DatabaseConfig(BaseSettings): | |||
| default=3600, | |||
| ) | |||
| SQLALCHEMY_POOL_USE_LIFO: bool = Field( | |||
| description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.", | |||
| default=False, | |||
| ) | |||
| SQLALCHEMY_POOL_PRE_PING: bool = Field( | |||
| description="If True, enables connection pool pre-ping feature to check connections.", | |||
| default=False, | |||
| @@ -199,6 +204,7 @@ class DatabaseConfig(BaseSettings): | |||
| "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, | |||
| "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, | |||
| "connect_args": connect_args, | |||
| "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, | |||
| } | |||
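For reference, `pool_use_lifo` is a standard SQLAlchemy `create_engine()` option: with LIFO checkout the most recently returned connections are handed out first, so surplus connections stay idle and can be reaped by server-side timeouts, keeping the effective pool small during quiet periods. A self-contained sketch with a placeholder DSN, not Dify's actual wiring:

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "postgresql+psycopg2://user:password@localhost/dify",  # placeholder DSN
    pool_size=30,
    pool_recycle=3600,
    pool_pre_ping=False,
    pool_use_lifo=True,  # hand out the most recently returned connection first
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
```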
| @@ -12,6 +12,16 @@ class OTelConfig(BaseSettings): | |||
| default=False, | |||
| ) | |||
| OTLP_TRACE_ENDPOINT: str = Field( | |||
| description="OTLP trace endpoint", | |||
| default="", | |||
| ) | |||
| OTLP_METRIC_ENDPOINT: str = Field( | |||
| description="OTLP metric endpoint", | |||
| default="", | |||
| ) | |||
| OTLP_BASE_ENDPOINT: str = Field( | |||
| description="OTLP base endpoint", | |||
| default="http://localhost:4318", | |||
| @@ -151,6 +151,7 @@ class AppApi(Resource): | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") | |||
| parser.add_argument("max_active_requests", type=int, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| @@ -35,16 +35,20 @@ class AppMCPServerController(Resource): | |||
| @get_app_model | |||
| @marshal_with(app_server_fields) | |||
| def post(self, app_model): | |||
| # The current user's role in the tenant-account (ta) table must be editor, admin, or owner | |||
| if not current_user.is_editor: | |||
| raise NotFound() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("description", type=str, required=True, location="json") | |||
| parser.add_argument("description", type=str, required=False, location="json") | |||
| parser.add_argument("parameters", type=dict, required=True, location="json") | |||
| args = parser.parse_args() | |||
| description = args.get("description") | |||
| if not description: | |||
| description = app_model.description or "" | |||
| server = AppMCPServer( | |||
| name=app_model.name, | |||
| description=args["description"], | |||
| description=description, | |||
| parameters=json.dumps(args["parameters"], ensure_ascii=False), | |||
| status=AppMCPServerStatus.ACTIVE, | |||
| app_id=app_model.id, | |||
| @@ -65,14 +69,22 @@ class AppMCPServerController(Resource): | |||
| raise NotFound() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("id", type=str, required=True, location="json") | |||
| parser.add_argument("description", type=str, required=True, location="json") | |||
| parser.add_argument("description", type=str, required=False, location="json") | |||
| parser.add_argument("parameters", type=dict, required=True, location="json") | |||
| parser.add_argument("status", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() | |||
| if not server: | |||
| raise NotFound() | |||
| server.description = args["description"] | |||
| description = args.get("description") | |||
| if description is not None: | |||
| server.description = description if description else (app_model.description or "") | |||
| server.parameters = json.dumps(args["parameters"], ensure_ascii=False) | |||
| if args["status"]: | |||
| if args["status"] not in [status.value for status in AppMCPServerStatus]: | |||
| @@ -2,6 +2,7 @@ from datetime import datetime | |||
| from decimal import Decimal | |||
| import pytz | |||
| import sqlalchemy as sa | |||
| from flask import jsonify | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse | |||
| @@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode | |||
| from models import AppMode, Message | |||
| class DailyMessageStatistic(Resource): | |||
| @@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource): | |||
| parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """SELECT | |||
| DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| COUNT(DISTINCT messages.conversation_id) AS conversation_count | |||
| FROM | |||
| messages | |||
| WHERE | |||
| app_id = :app_id""" | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| stmt = ( | |||
| sa.select( | |||
| sa.func.date( | |||
| sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz")) | |||
| ).label("date"), | |||
| sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), | |||
| ) | |||
| .select_from(Message) | |||
| .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) | |||
| ) | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += " AND created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| stmt = stmt.where(Message.created_at >= start_datetime_utc) | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| stmt = stmt.where(Message.created_at < end_datetime_utc) | |||
| sql_query += " AND created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += " GROUP BY date ORDER BY date" | |||
| stmt = stmt.group_by("date").order_by("date") | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) | |||
| rs = conn.execute(stmt, {"tz": account.timezone}) | |||
| for row in rs: | |||
| response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) | |||
| return jsonify({"data": response_data}) | |||
| @@ -68,13 +68,18 @@ def _create_pagination_parser(): | |||
| return parser | |||
| def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: | |||
| value_type = workflow_draft_var.value_type | |||
| return value_type.exposed_type().value | |||
| _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { | |||
| "id": fields.String, | |||
| "type": fields.String(attribute=lambda model: model.get_variable_type()), | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "value_type": fields.String(attribute=_serialize_variable_type), | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "value_type": fields.String(attribute=_serialize_variable_type), | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): | |||
| "name": v.name, | |||
| "description": v.description, | |||
| "selector": v.selector, | |||
| "value_type": v.value_type.value, | |||
| "value_type": v.value_type.exposed_type().value, | |||
| "value": v.value, | |||
| # Do not track edited for env vars. | |||
| "edited": False, | |||
| @@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ | |||
| raise AppNotFoundError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode == AppMode.CHANNEL: | |||
| raise AppNotFoundError() | |||
| if mode is not None: | |||
| if isinstance(mode, list): | |||
| @@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException): | |||
| class PasswordResetRateLimitExceededError(BaseHTTPException): | |||
| error_code = "password_reset_rate_limit_exceeded" | |||
| description = "Too many password reset emails have been sent. Please try again in 1 minutes." | |||
| description = "Too many password reset emails have been sent. Please try again in 1 minute." | |||
| code = 429 | |||
| class EmailChangeRateLimitExceededError(BaseHTTPException): | |||
| error_code = "email_change_rate_limit_exceeded" | |||
| description = "Too many email change emails have been sent. Please try again in 1 minute." | |||
| code = 429 | |||
| class OwnerTransferRateLimitExceededError(BaseHTTPException): | |||
| error_code = "owner_transfer_rate_limit_exceeded" | |||
| description = "Too many owner transfer emails have been sent. Please try again in 1 minute." | |||
| code = 429 | |||
| @@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException): | |||
| error_code = "email_password_reset_limit" | |||
| description = "Too many failed password reset attempts. Please try again in 24 hours." | |||
| code = 429 | |||
| class EmailChangeLimitError(BaseHTTPException): | |||
| error_code = "email_change_limit" | |||
| description = "Too many failed email change attempts. Please try again in 24 hours." | |||
| code = 429 | |||
| class EmailAlreadyInUseError(BaseHTTPException): | |||
| error_code = "email_already_in_use" | |||
| description = "A user with this email already exists." | |||
| code = 400 | |||
| class OwnerTransferLimitError(BaseHTTPException): | |||
| error_code = "owner_transfer_limit" | |||
| description = "Too many failed owner transfer attempts. Please try again in 24 hours." | |||
| code = 429 | |||
| class NotOwnerError(BaseHTTPException): | |||
| error_code = "not_owner" | |||
| description = "You are not the owner of the workspace." | |||
| code = 400 | |||
| class CannotTransferOwnerToSelfError(BaseHTTPException): | |||
| error_code = "cannot_transfer_owner_to_self" | |||
| description = "You cannot transfer ownership to yourself." | |||
| code = 400 | |||
| class MemberNotInTenantError(BaseHTTPException): | |||
| error_code = "member_not_in_tenant" | |||
| description = "The member is not in the workspace." | |||
| code = 400 | |||
| @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): | |||
| code = 415 | |||
| class HighQualityDatasetOnlyError(BaseHTTPException): | |||
| error_code = "high_quality_dataset_only" | |||
| description = "Current operation only supports 'high-quality' datasets." | |||
| code = 400 | |||
| class DatasetNotInitializedError(BaseHTTPException): | |||
| error_code = "dataset_not_initialized" | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| @@ -4,10 +4,20 @@ import pytz | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, fields, marshal_with, reqparse | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| from controllers.console.auth.error import ( | |||
| EmailAlreadyInUseError, | |||
| EmailChangeLimitError, | |||
| EmailCodeError, | |||
| InvalidEmailError, | |||
| InvalidTokenError, | |||
| ) | |||
| from controllers.console.error import AccountNotFound, EmailSendIpLimitError | |||
| from controllers.console.workspace.error import ( | |||
| AccountAlreadyInitedError, | |||
| CurrentPasswordIncorrectError, | |||
| @@ -18,15 +28,17 @@ from controllers.console.workspace.error import ( | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_enabled, | |||
| enable_change_email, | |||
| enterprise_license_required, | |||
| only_edition_cloud, | |||
| setup_required, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from fields.member_fields import account_fields | |||
| from libs.helper import TimestampField, timezone | |||
| from libs.helper import TimestampField, email, extract_remote_ip, timezone | |||
| from libs.login import login_required | |||
| from models import AccountIntegrate, InvitationCode | |||
| from models.account import Account | |||
| from services.account_service import AccountService | |||
| from services.billing_service import BillingService | |||
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError | |||
| @@ -369,6 +381,134 @@ class EducationAutoCompleteApi(Resource): | |||
| return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) | |||
| class ChangeEmailSendEmailApi(Resource): | |||
| @enable_change_email | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("language", type=str, required=False, location="json") | |||
| parser.add_argument("phase", type=str, required=False, location="json") | |||
| parser.add_argument("token", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| ip_address = extract_remote_ip(request) | |||
| if AccountService.is_email_send_ip_limit(ip_address): | |||
| raise EmailSendIpLimitError() | |||
| if args["language"] is not None and args["language"] == "zh-Hans": | |||
| language = "zh-Hans" | |||
| else: | |||
| language = "en-US" | |||
| account = None | |||
| user_email = args["email"] | |||
| if args["phase"] is not None and args["phase"] == "new_email": | |||
| if args["token"] is None: | |||
| raise InvalidTokenError() | |||
| reset_data = AccountService.get_change_email_data(args["token"]) | |||
| if reset_data is None: | |||
| raise InvalidTokenError() | |||
| user_email = reset_data.get("email", "") | |||
| if user_email != current_user.email: | |||
| raise InvalidEmailError() | |||
| else: | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() | |||
| if account is None: | |||
| raise AccountNotFound() | |||
| token = AccountService.send_change_email_email( | |||
| account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"] | |||
| ) | |||
| return {"result": "success", "data": token} | |||
| class ChangeEmailCheckApi(Resource): | |||
| @enable_change_email | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("code", type=str, required=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| user_email = args["email"] | |||
| is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"]) | |||
| if is_change_email_error_rate_limit: | |||
| raise EmailChangeLimitError() | |||
| token_data = AccountService.get_change_email_data(args["token"]) | |||
| if token_data is None: | |||
| raise InvalidTokenError() | |||
| if user_email != token_data.get("email"): | |||
| raise InvalidEmailError() | |||
| if args["code"] != token_data.get("code"): | |||
| AccountService.add_change_email_error_rate_limit(args["email"]) | |||
| raise EmailCodeError() | |||
| # Verified; revoke the first token | |||
| AccountService.revoke_change_email_token(args["token"]) | |||
| # Refresh token data by generating a new token | |||
| _, new_token = AccountService.generate_change_email_token( | |||
| user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={} | |||
| ) | |||
| AccountService.reset_change_email_error_rate_limit(args["email"]) | |||
| return {"is_valid": True, "email": token_data.get("email"), "token": new_token} | |||
| class ChangeEmailResetApi(Resource): | |||
| @enable_change_email | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("new_email", type=email, required=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| reset_data = AccountService.get_change_email_data(args["token"]) | |||
| if not reset_data: | |||
| raise InvalidTokenError() | |||
| AccountService.revoke_change_email_token(args["token"]) | |||
| if not AccountService.check_email_unique(args["new_email"]): | |||
| raise EmailAlreadyInUseError() | |||
| old_email = reset_data.get("old_email", "") | |||
| if current_user.email != old_email: | |||
| raise AccountNotFound() | |||
| updated_account = AccountService.update_account(current_user, email=args["new_email"]) | |||
| return updated_account | |||
| class CheckEmailUnique(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| args = parser.parse_args() | |||
| if not AccountService.check_email_unique(args["email"]): | |||
| raise EmailAlreadyInUseError() | |||
| return {"result": "success"} | |||
| # Register API resources | |||
| api.add_resource(AccountInitApi, "/account/init") | |||
| api.add_resource(AccountProfileApi, "/account/profile") | |||
| @@ -385,5 +525,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") | |||
| api.add_resource(EducationVerifyApi, "/account/education/verify") | |||
| api.add_resource(EducationApi, "/account/education") | |||
| api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") | |||
| # Change email | |||
| api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") | |||
| api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") | |||
| api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") | |||
| api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") | |||
| # api.add_resource(AccountEmailApi, '/account/email') | |||
| # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') | |||
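Taken together with the request parsers above, the new routes form a send-code, verify-code, commit flow. The sketch below exercises them with `requests`; the host, the `/console/api` prefix, the bearer header, and the simplification of skipping the `phase="new_email"` round for the new address are all assumptions made for illustration.

```python
import requests

BASE = "https://dify.example.com/console/api"  # hypothetical host and API prefix
session = requests.Session()
session.headers["Authorization"] = "Bearer <console-session-token>"  # placeholder auth

# 1. Ask for a verification code to be emailed to the current address.
resp = session.post(f"{BASE}/account/change-email", json={"email": "old@example.com", "language": "en-US"})
token = resp.json()["data"]

# 2. Verify the emailed code; a refreshed token comes back on success.
resp = session.post(
    f"{BASE}/account/change-email/validity",
    json={"email": "old@example.com", "code": "123456", "token": token},
)
new_token = resp.json()["token"]

# 3. Commit the new address with the refreshed token.
session.post(f"{BASE}/account/change-email/reset", json={"new_email": "new@example.com", "token": new_token})
```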
| @@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException): | |||
| code = 400 | |||
| class ProviderRequestFailedError(BaseHTTPException): | |||
| error_code = "provider_request_failed" | |||
| description = None | |||
| code = 400 | |||
| class InvalidInvitationCodeError(BaseHTTPException): | |||
| error_code = "invalid_invitation_code" | |||
| description = "Invalid invitation code." | |||
| @@ -1,22 +1,34 @@ | |||
| from urllib import parse | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, abort, marshal_with, reqparse | |||
| import services | |||
| from configs import dify_config | |||
| from controllers.console import api | |||
| from controllers.console.error import WorkspaceMembersLimitExceeded | |||
| from controllers.console.auth.error import ( | |||
| CannotTransferOwnerToSelfError, | |||
| EmailCodeError, | |||
| InvalidEmailError, | |||
| InvalidTokenError, | |||
| MemberNotInTenantError, | |||
| NotOwnerError, | |||
| OwnerTransferLimitError, | |||
| ) | |||
| from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_resource_check, | |||
| is_allow_transfer_owner, | |||
| setup_required, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from fields.member_fields import account_with_role_list_fields | |||
| from libs.helper import extract_remote_ip | |||
| from libs.login import login_required | |||
| from models.account import Account, TenantAccountRole | |||
| from services.account_service import RegisterService, TenantService | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| from services.errors.account import AccountAlreadyInTenantError | |||
| from services.feature_service import FeatureService | |||
| @@ -156,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource): | |||
| return {"result": "success", "accounts": members}, 200 | |||
| class SendOwnerTransferEmailApi(Resource): | |||
| """Send owner transfer email.""" | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @is_allow_transfer_owner | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("language", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| ip_address = extract_remote_ip(request) | |||
| if AccountService.is_email_send_ip_limit(ip_address): | |||
| raise EmailSendIpLimitError() | |||
| # check if the current user is the owner of the workspace | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| if args["language"] is not None and args["language"] == "zh-Hans": | |||
| language = "zh-Hans" | |||
| else: | |||
| language = "en-US" | |||
| email = current_user.email | |||
| token = AccountService.send_owner_transfer_email( | |||
| account=current_user, | |||
| email=email, | |||
| language=language, | |||
| workspace_name=current_user.current_tenant.name, | |||
| ) | |||
| return {"result": "success", "data": token} | |||
| class OwnerTransferCheckApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @is_allow_transfer_owner | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("code", type=str, required=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # check if the current user is the owner of the workspace | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| user_email = current_user.email | |||
| is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email) | |||
| if is_owner_transfer_error_rate_limit: | |||
| raise OwnerTransferLimitError() | |||
| token_data = AccountService.get_owner_transfer_data(args["token"]) | |||
| if token_data is None: | |||
| raise InvalidTokenError() | |||
| if user_email != token_data.get("email"): | |||
| raise InvalidEmailError() | |||
| if args["code"] != token_data.get("code"): | |||
| AccountService.add_owner_transfer_error_rate_limit(user_email) | |||
| raise EmailCodeError() | |||
| # Verified; revoke the first token | |||
| AccountService.revoke_owner_transfer_token(args["token"]) | |||
| # Refresh token data by generating a new token | |||
| _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={}) | |||
| AccountService.reset_owner_transfer_error_rate_limit(user_email) | |||
| return {"is_valid": True, "email": token_data.get("email"), "token": new_token} | |||
| class OwnerTransfer(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @is_allow_transfer_owner | |||
| def post(self, member_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # check if the current user is the owner of the workspace | |||
| if not TenantService.is_owner(current_user, current_user.current_tenant): | |||
| raise NotOwnerError() | |||
| if current_user.id == str(member_id): | |||
| raise CannotTransferOwnerToSelfError() | |||
| transfer_token_data = AccountService.get_owner_transfer_data(args["token"]) | |||
| if not transfer_token_data: | |||
| raise InvalidTokenError() | |||
| if transfer_token_data.get("email") != current_user.email: | |||
| raise InvalidEmailError() | |||
| AccountService.revoke_owner_transfer_token(args["token"]) | |||
| member = db.session.get(Account, str(member_id)) | |||
| if not member: | |||
| abort(404) | |||
| else: | |||
| member_account = member | |||
| if not TenantService.is_member(member_account, current_user.current_tenant): | |||
| raise MemberNotInTenantError() | |||
| try: | |||
| assert member is not None, "Member not found" | |||
| TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) | |||
| AccountService.send_new_owner_transfer_notify_email( | |||
| account=member, | |||
| email=member.email, | |||
| workspace_name=current_user.current_tenant.name, | |||
| ) | |||
| AccountService.send_old_owner_transfer_notify_email( | |||
| account=current_user, | |||
| email=current_user.email, | |||
| workspace_name=current_user.current_tenant.name, | |||
| new_owner_email=member.email, | |||
| ) | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| return {"result": "success"} | |||
| api.add_resource(MemberListApi, "/workspaces/current/members") | |||
| api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") | |||
| api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>") | |||
| api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role") | |||
| api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") | |||
| # owner transfer | |||
| api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") | |||
| api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") | |||
| api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer") | |||
| @@ -235,3 +235,29 @@ def email_password_login_enabled(view): | |||
| abort(403) | |||
| return decorated | |||
| def enable_change_email(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_change_email: | |||
| return view(*args, **kwargs) | |||
| # otherwise, return 403 | |||
| abort(403) | |||
| return decorated | |||
| def is_allow_transfer_owner(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.is_allow_transfer_workspace: | |||
| return view(*args, **kwargs) | |||
| # otherwise, return 403 | |||
| abort(403) | |||
| return decorated | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from dateutil.parser import isoparse | |||
| from flask_restful import Resource, fields, marshal_with, reqparse | |||
| from flask_restful.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from werkzeug.exceptions import InternalServerError | |||
| from controllers.service_api import api | |||
| @@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from libs import helper | |||
| from libs.helper import TimestampField | |||
| from models.model import App, AppMode, EndUser | |||
| from models.workflow import WorkflowRun | |||
| from repositories.factory import DifyAPIRepositoryFactory | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| from services.workflow_app_service import WorkflowAppService | |||
| @@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource): | |||
| if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: | |||
| raise NotWorkflowAppError() | |||
| workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() | |||
| # Use repository to get workflow run | |||
| session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) | |||
| workflow_run = workflow_run_repo.get_workflow_run_by_id( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| run_id=workflow_run_id, | |||
| ) | |||
| return workflow_run | |||
| @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): | |||
| code = 415 | |||
| class HighQualityDatasetOnlyError(BaseHTTPException): | |||
| error_code = "high_quality_dataset_only" | |||
| description = "Current operation only supports 'high-quality' datasets." | |||
| code = 400 | |||
| class DatasetNotInitializedError(BaseHTTPException): | |||
| error_code = "dataset_not_initialized" | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| @@ -3,6 +3,8 @@ import logging | |||
| import uuid | |||
| from typing import Optional, Union, cast | |||
| from sqlalchemy import select | |||
| from core.agent.entities import AgentEntity, AgentToolEntity | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig | |||
| @@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner): | |||
| if isinstance(prompt_message, SystemPromptMessage): | |||
| result.append(prompt_message) | |||
| messages: list[Message] = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.conversation_id == self.message.conversation_id, | |||
| messages = ( | |||
| ( | |||
| db.session.execute( | |||
| select(Message) | |||
| .where(Message.conversation_id == self.message.conversation_id) | |||
| .order_by(Message.created_at.desc()) | |||
| ) | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| .scalars() | |||
| .all() | |||
| ) | |||
| @@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter): | |||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||
| TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value | |||
| ANY = CommonParameterType.ANY.value | |||
| # deprecated, should not use. | |||
| SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value | |||
| @@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.repositories.draft_variable_repository import ( | |||
| DraftVariableSaverFactory, | |||
| ) | |||
| @@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING | |||
| else: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=workflow_triggered_from, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -16,9 +16,10 @@ from core.app.entities.queue_entities import ( | |||
| QueueTextChunkEvent, | |||
| ) | |||
| from core.moderation.base import ModerationError | |||
| from core.variables.variables import VariableUnion | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| @@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| user_id = None | |||
| user_id: str | None = None | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| @@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| session.commit() | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.QUERY: query, | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.CONVERSATION_ID: self.conversation.id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, | |||
| } | |||
| system_inputs = SystemVariable( | |||
| query=query, | |||
| files=files, | |||
| conversation_id=self.conversation.id, | |||
| user_id=user_id, | |||
| dialogue_count=self._dialogue_count, | |||
| app_id=app_config.app_id, | |||
| workflow_id=app_config.workflow_id, | |||
| workflow_execution_id=self.application_generate_entity.workflow_run_id, | |||
| ) | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| # Based on the definition of `VariableUnion`, | |||
| # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | |||
| conversation_variables=cast(list[VariableUnion], conversation_variables), | |||
| ) | |||
| # init graph | |||
| @@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| @@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| self._workflow_cycle_manager = WorkflowCycleManager( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_system_variables={ | |||
| SystemVariableKey.QUERY: message.query, | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| workflow_system_variables=SystemVariable( | |||
| query=message.query, | |||
| files=application_generate_entity.files, | |||
| conversation_id=conversation.id, | |||
| user_id=user_session_id, | |||
| dialogue_count=dialogue_count, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_execution_id=application_generate_entity.workflow_run_id, | |||
| ), | |||
| workflow_info=CycleManagerWorkflowInfo( | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType(workflow.type), | |||
| @@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__) | |||
| class AppRunner: | |||
| def get_pre_calculate_rest_tokens( | |||
| self, | |||
| app_record: App, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: Mapping[str, str], | |||
| files: Sequence["File"], | |||
| query: Optional[str] = None, | |||
| ) -> int: | |||
| """ | |||
| Get pre calculate rest tokens | |||
| :param app_record: app record | |||
| :param model_config: model config entity | |||
| :param prompt_template_entity: prompt template entity | |||
| :param inputs: inputs | |||
| :param files: files | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| # Invoke model | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=model_config.provider_model_bundle, model=model_config.model | |||
| ) | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if parameter_rule.name == "max_tokens" or ( | |||
| parameter_rule.use_template and parameter_rule.use_template == "max_tokens" | |||
| ): | |||
| max_tokens = ( | |||
| model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(parameter_rule.use_template or "") | |||
| ) or 0 | |||
| if model_context_tokens is None: | |||
| return -1 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| # get prompt messages without memory and context | |||
| prompt_messages, stop = self.organize_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=model_config, | |||
| prompt_template_entity=prompt_template_entity, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| ) | |||
| prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||
| rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens | |||
| if rest_tokens < 0: | |||
| raise InvokeBadRequestError( | |||
| "Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||
| "or shrink the max token, or switch to a llm with a larger token limit size." | |||
| ) | |||
| return rest_tokens | |||
| def recalc_llm_max_tokens( | |||
| self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] | |||
| ): | |||
| @@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| @@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING | |||
| else: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=workflow_triggered_from, | |||
| ) | |||
| # Create workflow node execution repository | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| # Create session factory | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| # Create workflow execution(aka workflow run) repository | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, | |||
| ) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| ) | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| @@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| files = self.application_generate_entity.files | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, | |||
| } | |||
| system_inputs = SystemVariable( | |||
| files=files, | |||
| user_id=user_id, | |||
| app_id=app_config.app_id, | |||
| workflow_id=app_config.workflow_id, | |||
| workflow_execution_id=self.application_generate_entity.workflow_execution_id, | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| @@ -3,7 +3,6 @@ import time | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| @@ -55,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| @@ -68,7 +67,6 @@ from models.workflow import ( | |||
| Workflow, | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowRun, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| @@ -109,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline: | |||
| self._workflow_cycle_manager = WorkflowCycleManager( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_system_variables={ | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, | |||
| }, | |||
| workflow_system_variables=SystemVariable( | |||
| files=application_generate_entity.files, | |||
| user_id=user_session_id, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_execution_id=application_generate_entity.workflow_execution_id, | |||
| ), | |||
| workflow_info=CycleManagerWorkflowInfo( | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType(workflow.type), | |||
| @@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: | |||
| workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) | |||
| assert workflow_run is not None | |||
| invoke_from = self._application_generate_entity.invoke_from | |||
| if invoke_from == InvokeFrom.SERVICE_API: | |||
| created_from = WorkflowAppLogCreatedFrom.SERVICE_API | |||
| @@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline: | |||
| return | |||
| workflow_app_log = WorkflowAppLog() | |||
| workflow_app_log.tenant_id = workflow_run.tenant_id | |||
| workflow_app_log.app_id = workflow_run.app_id | |||
| workflow_app_log.workflow_id = workflow_run.workflow_id | |||
| workflow_app_log.workflow_run_id = workflow_run.id | |||
| workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| workflow_app_log.app_id = self._application_generate_entity.app_config.app_id | |||
| workflow_app_log.workflow_id = workflow_execution.workflow_id | |||
| workflow_app_log.workflow_run_id = workflow_execution.id_ | |||
| workflow_app_log.created_from = created_from.value | |||
| workflow_app_log.created_by_role = self._created_by_role | |||
| workflow_app_log.created_by = self._user_id | |||
| @@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import ( | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| @@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| environment_variables=workflow.environment_variables, | |||
| ) | |||
| @@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| # init variable pool | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| environment_variables=workflow.environment_variables, | |||
| ) | |||
| @@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError): | |||
| class WorkflowRunNotFoundError(RecordNotFoundError): | |||
| def __init__(self, workflow_run_id: str): | |||
| super().__init__("WorkflowRun", workflow_run_id) | |||
| class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): | |||
| def __init__(self, workflow_node_execution_id: str): | |||
| super().__init__("WorkflowNodeExecution", workflow_node_execution_id) | |||
| @@ -14,6 +14,7 @@ class CommonParameterType(StrEnum): | |||
| APP_SELECTOR = "app-selector" | |||
| MODEL_SELECTOR = "model-selector" | |||
| TOOLS_SELECTOR = "array[tools]" | |||
| ANY = "any" | |||
| # Dynamic select parameter | |||
| # Once you are not sure about the available options until authorization is done | |||
| @@ -7,13 +7,6 @@ if TYPE_CHECKING: | |||
| _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None | |||
| class ToolFileParser: | |||
| @staticmethod | |||
| def get_tool_file_manager() -> "ToolFileManager": | |||
| assert _tool_file_manager_factory is not None | |||
| return _tool_file_manager_factory() | |||
| def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: | |||
| global _tool_file_manager_factory | |||
| _tool_file_manager_factory = factory | |||
| @@ -5,6 +5,8 @@ from base64 import b64encode | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from core.variables.utils import SegmentJSONEncoder | |||
| class TemplateTransformer(ABC): | |||
| _code_placeholder: str = "{{code}}" | |||
| @@ -43,17 +45,13 @@ class TemplateTransformer(ABC): | |||
| result_str = cls.extract_result_str_from_response(response) | |||
| result = json.loads(result_str) | |||
| except json.JSONDecodeError as e: | |||
| raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...") | |||
| raise ValueError(f"Failed to parse JSON response: {str(e)}.") | |||
| except ValueError as e: | |||
| # Re-raise ValueError from extract_result_str_from_response | |||
| raise e | |||
| except Exception as e: | |||
| raise ValueError(f"Unexpected error during response transformation: {str(e)}") | |||
| # Check if the result contains an error | |||
| if isinstance(result, dict) and "error" in result: | |||
| raise ValueError(f"JavaScript execution error: {result['error']}") | |||
| if not isinstance(result, dict): | |||
| raise ValueError(f"Result must be a dict, got {type(result).__name__}") | |||
| if not all(isinstance(k, str) for k in result): | |||
| @@ -95,7 +93,7 @@ class TemplateTransformer(ABC): | |||
| @classmethod | |||
| def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: | |||
| inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() | |||
| inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() | |||
| input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") | |||
| return input_base64_encoded | |||
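The `serialize_inputs` change above keeps the same wire format: inputs are dumped to JSON (now via `SegmentJSONEncoder`) and base64-encoded before being embedded in the code template. A minimal round-trip sketch of that format, using plain `json` and illustrative values rather than the actual transformer:

```python
import json
from base64 import b64decode, b64encode

inputs = {"query": "hello", "count": 3}  # illustrative inputs

# Encode: JSON -> UTF-8 bytes -> base64 string (mirrors serialize_inputs).
encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode("utf-8")

# Decode: base64 string -> bytes -> JSON (what the generated template has to undo).
decoded = json.loads(b64decode(encoded))
assert decoded == inputs
```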
| @@ -1,52 +0,0 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import os | |||
| import time | |||
| from pydantic import BaseModel, Field | |||
| from configs import dify_config | |||
| class SignedUrlParams(BaseModel): | |||
| sign_key: str = Field(..., description="The sign key") | |||
| timestamp: str = Field(..., description="Timestamp") | |||
| nonce: str = Field(..., description="Nonce") | |||
| sign: str = Field(..., description="Signature") | |||
| class UrlSigner: | |||
| @classmethod | |||
| def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: | |||
| signed_url_params = cls.get_signed_url_params(sign_key, prefix) | |||
| return ( | |||
| f"{url}?timestamp={signed_url_params.timestamp}" | |||
| f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" | |||
| ) | |||
| @classmethod | |||
| def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||
| return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) | |||
| @classmethod | |||
| def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: | |||
| recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||
| return sign == recalculated_sign | |||
| @classmethod | |||
| def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: | |||
| if not dify_config.SECRET_KEY: | |||
| raise Exception("SECRET_KEY is not set") | |||
| data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return encoded_sign | |||
| @@ -148,9 +148,11 @@ class LLMGenerator: | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=model_config.get("provider", ""), | |||
| model=model_config.get("name", ""), | |||
| ) | |||
| try: | |||
| @@ -240,7 +240,7 @@ def refresh_authorization( | |||
| response = requests.post(token_url, data=params) | |||
| if not response.ok: | |||
| raise ValueError(f"Token refresh failed: HTTP {response.status_code}") | |||
| return OAuthTokens.parse_obj(response.json()) | |||
| return OAuthTokens.model_validate(response.json()) | |||
| def register_client( | |||
| @@ -148,9 +148,7 @@ class MCPServerStreamableHTTPRequestHandler: | |||
| if not self.end_user: | |||
| raise ValueError("User not found") | |||
| request = cast(types.CallToolRequest, self.request.root) | |||
| args = request.params.arguments | |||
| if not args: | |||
| raise ValueError("No arguments provided") | |||
| args = request.params.arguments or {} | |||
| if self.app.mode in {AppMode.WORKFLOW.value}: | |||
| args = {"inputs": args} | |||
| elif self.app.mode in {AppMode.COMPLETION.value}: | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| import queue | |||
| from collections.abc import Callable | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError | |||
| from contextlib import ExitStack | |||
| from datetime import timedelta | |||
| from types import TracebackType | |||
| @@ -171,23 +171,41 @@ class BaseSession( | |||
| self._session_read_timeout_seconds = read_timeout_seconds | |||
| self._in_flight = {} | |||
| self._exit_stack = ExitStack() | |||
| # Initialize executor and future to None for proper cleanup checks | |||
| self._executor: ThreadPoolExecutor | None = None | |||
| self._receiver_future: Future | None = None | |||
| def __enter__(self) -> Self: | |||
| self._executor = ThreadPoolExecutor() | |||
| # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1 | |||
| # ensures no unnecessary threads are created. | |||
| self._executor = ThreadPoolExecutor(max_workers=1) | |||
| self._receiver_future = self._executor.submit(self._receive_loop) | |||
| return self | |||
| def check_receiver_status(self) -> None: | |||
| if self._receiver_future.done(): | |||
| """`check_receiver_status` ensures that any exceptions raised during the | |||
| execution of `_receive_loop` are retrieved and propagated.""" | |||
| if self._receiver_future and self._receiver_future.done(): | |||
| self._receiver_future.result() | |||
| def __exit__( | |||
| self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | |||
| ) -> None: | |||
| self._exit_stack.close() | |||
| self._read_stream.put(None) | |||
| self._write_stream.put(None) | |||
| # Wait for the receiver loop to finish | |||
| if self._receiver_future: | |||
| try: | |||
| self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds | |||
| except TimeoutError: | |||
# If the receiver loop is still running after the timeout, fall through and force the shutdown below | |||
| pass | |||
| # Shutdown the executor | |||
| if self._executor: | |||
| self._executor.shutdown(wait=True) | |||
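The session changes above pin the receive loop to a single worker and make teardown explicit: `__exit__` pushes sentinels into the streams, waits up to five seconds for the loop to drain, then shuts the executor down, while `check_receiver_status` re-raises anything the loop died with. A stripped-down sketch of that lifecycle, using an illustrative queue-backed stream rather than the real MCP session types:

```python
import queue
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError


class ReceiverLifecycle:
    def __init__(self) -> None:
        self._read_stream: queue.Queue = queue.Queue()
        self._executor: ThreadPoolExecutor | None = None
        self._receiver_future: Future | None = None

    def __enter__(self) -> "ReceiverLifecycle":
        # One worker is enough: the pool only ever runs the receive loop.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._receiver_future = self._executor.submit(self._receive_loop)
        return self

    def _receive_loop(self) -> None:
        # Block on the stream until the shutdown sentinel (None) arrives.
        while (item := self._read_stream.get()) is not None:
            pass  # dispatch `item` here in a real session

    def check_receiver_status(self) -> None:
        # If the loop has finished, result() re-raises any exception it hit.
        if self._receiver_future and self._receiver_future.done():
            self._receiver_future.result()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._read_stream.put(None)  # ask the loop to stop
        if self._receiver_future:
            try:
                self._receiver_future.result(timeout=5.0)
            except TimeoutError:
                pass  # loop did not exit in time; still shut the pool down
        if self._executor:
            self._executor.shutdown(wait=True)
```

Used as a context manager, this guarantees the worker thread is joined on exit even when the loop overruns the grace period.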
| def send_request( | |||
| self, | |||
| request: SendRequestT, | |||
| @@ -1,6 +1,8 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.file import file_manager | |||
| from core.model_manager import ModelInstance | |||
| @@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.model import AppMode, Conversation, Message, MessageFile | |||
| from models.workflow import WorkflowRun | |||
| from models.workflow import Workflow, WorkflowRun | |||
| class TokenBufferMemory: | |||
| def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: | |||
| def __init__( | |||
| self, | |||
| conversation: Conversation, | |||
| model_instance: ModelInstance, | |||
| ) -> None: | |||
| self.conversation = conversation | |||
| self.model_instance = model_instance | |||
| @@ -36,20 +42,8 @@ class TokenBufferMemory: | |||
| app_record = self.conversation.app | |||
| # fetch limited messages, and return reversed | |||
| query = ( | |||
| db.session.query( | |||
| Message.id, | |||
| Message.query, | |||
| Message.answer, | |||
| Message.created_at, | |||
| Message.workflow_run_id, | |||
| Message.parent_message_id, | |||
| Message.answer_tokens, | |||
| ) | |||
| .filter( | |||
| Message.conversation_id == self.conversation.id, | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| stmt = ( | |||
| select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) | |||
| ) | |||
| if message_limit and message_limit > 0: | |||
| @@ -57,7 +51,9 @@ class TokenBufferMemory: | |||
| else: | |||
| message_limit = 500 | |||
| messages = query.limit(message_limit).all() | |||
| stmt = stmt.limit(message_limit) | |||
| messages = db.session.scalars(stmt).all() | |||
| # instead of all messages from the conversation, we only need to extract messages | |||
# that belong to the thread of the last message | |||
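The comment above refers to `extract_thread_messages`, which keeps only the messages on the thread ending at the most recent message instead of the whole conversation. A rough illustration of that idea, walking `parent_message_id` over a newest-first list; this is a sketch with stand-in types, not the actual helper:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Msg:  # illustrative stand-in for models.Message
    id: str
    parent_message_id: Optional[str]


def thread_of_last_message(messages_newest_first: list[Msg]) -> list[Msg]:
    thread: list[Msg] = []
    wanted: Optional[str] = None  # None means "start from the newest message"
    for message in messages_newest_first:
        if wanted is None or message.id == wanted:
            thread.append(message)
            wanted = message.parent_message_id  # follow the chain upwards
            if wanted is None:
                break  # reached the root of the thread
    return thread
```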
| @@ -74,18 +70,20 @@ class TokenBufferMemory: | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| if files: | |||
| file_extra_config = None | |||
| if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: | |||
| if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: | |||
| file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) | |||
| elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: | |||
| workflow_run = db.session.scalar( | |||
| select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id) | |||
| ) | |||
| if not workflow_run: | |||
| raise ValueError(f"Workflow run not found: {message.workflow_run_id}") | |||
| workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) | |||
| if not workflow: | |||
| raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| else: | |||
| if message.workflow_run_id: | |||
| workflow_run = ( | |||
| db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() | |||
| ) | |||
| if workflow_run and workflow_run.workflow: | |||
| file_extra_config = FileUploadConfigManager.convert( | |||
| workflow_run.workflow.features_dict, is_vision=False | |||
| ) | |||
| raise AssertionError(f"Invalid app mode: {self.conversation.mode}") | |||
| detail = ImagePromptMessageContent.DETAIL.LOW | |||
| if file_extra_config and app_record: | |||
| @@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| else: | |||
| node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) | |||
| return node_span | |||
| except Exception: | |||
| except Exception as e: | |||
| logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) | |||
| return None | |||
| def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: | |||
| @@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| start_time=convert_datetime_to_nanoseconds(node_execution.created_at), | |||
| end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), | |||
| attributes={ | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", | |||
| GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, | |||
| GEN_AI_FRAMEWORK: "dify", | |||
| INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), | |||
| @@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| start_time=convert_datetime_to_nanoseconds(node_execution.created_at), | |||
| end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), | |||
| attributes={ | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", | |||
| GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, | |||
| GEN_AI_FRAMEWORK: "dify", | |||
| GEN_AI_MODEL_NAME: process_data.get("model_name", ""), | |||
| @@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| start_time=convert_datetime_to_nanoseconds(trace_info.start_time), | |||
| end_time=convert_datetime_to_nanoseconds(trace_info.end_time), | |||
| attributes={ | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), | |||
| GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", | |||
| GEN_AI_USER_ID: str(user_id), | |||
| GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, | |||
| GEN_AI_FRAMEWORK: "dify", | |||
| @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( | |||
| UnitEnum, | |||
| ) | |||
| from core.ops.utils import filter_none_values | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models import EndUser, WorkflowNodeExecutionTriggeredFrom | |||
| @@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance): | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=service_account, | |||
| app_id=trace_info.metadata.get("app_id"), | |||
| app_id=app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( | |||
| LangSmithRunUpdateModel, | |||
| ) | |||
| from core.ops.utils import filter_none_values, generate_dotted_order | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| @@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance): | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=service_account, | |||
| app_id=trace_info.metadata.get("app_id"), | |||
| app_id=app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( | |||
| TraceTaskName, | |||
| WorkflowTraceInfo, | |||
| ) | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| @@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance): | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=service_account, | |||
| app_id=trace_info.metadata.get("app_id"), | |||
| app_id=app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| @@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance): | |||
| "trace_id": opik_trace_id, | |||
| "id": prepare_opik_uuid(created_at, node_execution_id), | |||
| "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), | |||
| "name": node_type, | |||
| "name": node_name, | |||
| "type": run_type, | |||
| "start_time": created_at, | |||
| "end_time": finished_at, | |||
| @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( | |||
| WorkflowTraceInfo, | |||
| ) | |||
| from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| @@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance): | |||
| service_account = self.get_service_account_with_tenant(app_id) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=session_factory, | |||
| user=service_account, | |||
| app_id=trace_info.metadata.get("app_id"), | |||
| app_id=app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator | |||
| from core.entities.parameter_entities import CommonParameterType | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.workflow.nodes.base.entities import NumberType | |||
| class PluginParameterOption(BaseModel): | |||
| @@ -38,6 +39,7 @@ class PluginParameterType(enum.StrEnum): | |||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||
| TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value | |||
| ANY = CommonParameterType.ANY.value | |||
| DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value | |||
| # deprecated, should not use. | |||
| @@ -151,6 +153,10 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): | |||
| if value and not isinstance(value, list): | |||
| raise ValueError("The tools selector must be a list.") | |||
| return value | |||
| case PluginParameterType.ANY: | |||
| if value and not isinstance(value, str | dict | list | NumberType): | |||
| raise ValueError("The var selector must be a string, dictionary, list or number.") | |||
| return value | |||
| case PluginParameterType.ARRAY: | |||
| if not isinstance(value, list): | |||
| # Try to parse JSON string for arrays | |||
| @@ -141,17 +141,6 @@ class PluginEntity(PluginInstallation): | |||
| return self | |||
| class GithubPackage(BaseModel): | |||
| repo: str | |||
| version: str | |||
| package: str | |||
| class GithubVersion(BaseModel): | |||
| repo: str | |||
| version: str | |||
| class GenericProviderID: | |||
| organization: str | |||
| plugin_name: str | |||
| @@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient): | |||
| "GET", | |||
| f"plugin/{tenant_id}/management/list", | |||
| PluginListResponse, | |||
| params={"page": 1, "page_size": 256}, | |||
| params={"page": 1, "page_size": 256, "response_type": "paged"}, | |||
| ) | |||
| return result.list | |||
| @@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient): | |||
| "GET", | |||
| f"plugin/{tenant_id}/management/list", | |||
| PluginListResponse, | |||
| params={"page": page, "page_size": page_size}, | |||
| params={"page": page, "page_size": page_size, "response_type": "paged"}, | |||
| ) | |||
| def upload_pkg( | |||
| @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): | |||
| if prompt_item.edition_type == "basic" or not prompt_item.edition_type: | |||
| if self.with_variable_tmpl: | |||
| vp = VariablePool() | |||
| vp = VariablePool.empty() | |||
| for k, v in inputs.items(): | |||
| if k.startswith("#"): | |||
| vp.add(k[1:-1].split("."), v) | |||
| @@ -1,10 +1,11 @@ | |||
| from typing import Any | |||
| from collections.abc import Sequence | |||
| from constants import UUID_NIL | |||
| from models import Message | |||
| def extract_thread_messages(messages: list[Any]): | |||
| thread_messages = [] | |||
| def extract_thread_messages(messages: Sequence[Message]): | |||
| thread_messages: list[Message] = [] | |||
| next_message = None | |||
| for message in messages: | |||
| @@ -1,3 +1,5 @@ | |||
| from sqlalchemy import select | |||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||
| from extensions.ext_database import db | |||
| from models.model import Message | |||
| @@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int: | |||
| Get the number of thread messages based on the parent message id. | |||
| """ | |||
| # Fetch all messages related to the conversation | |||
| query = ( | |||
| db.session.query( | |||
| Message.id, | |||
| Message.parent_message_id, | |||
| Message.answer, | |||
| ) | |||
| .filter( | |||
| Message.conversation_id == conversation_id, | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| ) | |||
| stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc()) | |||
| messages = query.all() | |||
| messages = db.session.scalars(stmt).all() | |||
| # Extract thread messages | |||
| thread_messages = extract_thread_messages(messages) | |||
| @@ -1,12 +0,0 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredNonAsciiCharsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.core import clean_extra_whitespace | |||
| # Returns "ITEM 1A: RISK FACTORS" | |||
| return clean_extra_whitespace(content) | |||
| @@ -1,15 +0,0 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| import re | |||
| from unstructured.cleaners.core import group_broken_paragraphs | |||
| para_split_re = re.compile(r"(\s*\n\s*){3}") | |||
| return group_broken_paragraphs(content, paragraph_split=para_split_re) | |||
| @@ -1,12 +0,0 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredNonAsciiCharsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.core import clean_non_ascii_chars | |||
| # Returns "This text contains non-ascii characters!" | |||
| return clean_non_ascii_chars(content) | |||
| @@ -1,12 +0,0 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredNonAsciiCharsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """Replaces unicode quote characters, such as the \x91 character in a string.""" | |||
| from unstructured.cleaners.core import replace_unicode_quotes | |||
| return replace_unicode_quotes(content) | |||
| @@ -1,11 +0,0 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredTranslateTextCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.translate import translate_text | |||
| return translate_text(content) | |||
| @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Optional | |||
| from flask import Flask, current_app | |||
| from sqlalchemy.orm import load_only | |||
| from sqlalchemy.orm import Session, load_only | |||
| from configs import dify_config | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| @@ -144,7 +144,8 @@ class RetrievalService: | |||
| @classmethod | |||
| def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: | |||
| return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| with Session(db.engine) as session: | |||
| return session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| @classmethod | |||
| def keyword_search( | |||
| @@ -4,6 +4,7 @@ from typing import Any, Optional | |||
| import tablestore # type: ignore | |||
| from pydantic import BaseModel, model_validator | |||
| from tablestore import BatchGetRowRequest, TableInBatchGetRowItem | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.field import Field | |||
| @@ -50,6 +51,29 @@ class TableStoreVector(BaseVector): | |||
| self._index_name = f"{collection_name}_idx" | |||
| self._tags_field = f"{Field.METADATA_KEY.value}_tags" | |||
| def create_collection(self, embeddings: list[list[float]], **kwargs): | |||
| dimension = len(embeddings[0]) | |||
| self._create_collection(dimension) | |||
| def get_by_ids(self, ids: list[str]) -> list[Document]: | |||
| docs = [] | |||
| request = BatchGetRowRequest() | |||
| columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] | |||
| rows_to_get = [[("id", _id)] for _id in ids] | |||
| request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) | |||
| result = self._tablestore_client.batch_get_row(request) | |||
| table_result = result.get_result_by_table(self._table_name) | |||
| for item in table_result: | |||
| if item.is_ok and item.row: | |||
| kv = {k: v for k, v, t in item.row.attribute_columns} | |||
| docs.append( | |||
| Document( | |||
| page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) | |||
| ) | |||
| ) | |||
| return docs | |||
| def get_type(self) -> str: | |||
| return VectorType.TABLESTORE | |||
| @@ -1,17 +0,0 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| class ClusterEntity(BaseModel): | |||
| """ | |||
| Model Config Entity. | |||
| """ | |||
| name: str | |||
| cluster_id: str | |||
| displayName: str | |||
| region: str | |||
| spendingLimit: Optional[int] = 1000 | |||
| version: str | |||
| createdBy: str | |||
| @@ -9,8 +9,7 @@ from __future__ import annotations | |||
| import contextlib | |||
| import mimetypes | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator, Iterable, Mapping | |||
| from collections.abc import Generator, Mapping | |||
| from io import BufferedReader, BytesIO | |||
| from pathlib import Path, PurePath | |||
| from typing import Any, Optional, Union | |||
| @@ -143,21 +142,3 @@ class Blob(BaseModel): | |||
| if self.source: | |||
| str_repr += f" {self.source}" | |||
| return str_repr | |||
| class BlobLoader(ABC): | |||
| """Abstract interface for blob loaders implementation. | |||
| Implementer should be able to load raw content from a datasource system according | |||
| to some criteria and return the raw content lazily as a stream of blobs. | |||
| """ | |||
| @abstractmethod | |||
| def yield_blobs( | |||
| self, | |||
| ) -> Iterable[Blob]: | |||
| """A lazy loader for raw data represented by Blob object. | |||
| Returns: | |||
| A generator over blobs | |||
| """ | |||
| @@ -1,47 +0,0 @@ | |||
| import logging | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredPDFExtractor(BaseExtractor): | |||
| """Load pdf files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| api_url: Unstructured API URL | |||
| api_key: Unstructured API Key | |||
| """ | |||
| def __init__(self, file_path: str, api_url: str, api_key: str): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| self._api_key = api_key | |||
| def extract(self) -> list[Document]: | |||
| if self._api_url: | |||
| from unstructured.partition.api import partition_via_api | |||
| elements = partition_via_api( | |||
| filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto" | |||
| ) | |||
| else: | |||
| from unstructured.partition.pdf import partition_pdf | |||
| elements = partition_pdf(filename=self._file_path, strategy="auto") | |||
| from unstructured.chunking.title import chunk_by_title | |||
| chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) | |||
| documents = [] | |||
| for chunk in chunks: | |||
| text = chunk.text.strip() | |||
| documents.append(Document(page_content=text)) | |||
| return documents | |||
| @@ -1,34 +0,0 @@ | |||
| import logging | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredTextExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__(self, file_path: str, api_url: str): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.text import partition_text | |||
| elements = partition_text(filename=self._file_path) | |||
| from unstructured.chunking.title import chunk_by_title | |||
| chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) | |||
| documents = [] | |||
| for chunk in chunks: | |||
| text = chunk.text.strip() | |||
| documents.append(Document(page_content=text)) | |||
| return documents | |||
| @@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast | |||
| from flask import Flask, current_app | |||
| from sqlalchemy import Float, and_, or_, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from sqlalchemy.orm import Session | |||
| from core.app.app_config.entities import ( | |||
| DatasetEntity, | |||
| @@ -598,7 +599,8 @@ class DatasetRetrieval: | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| with Session(db.engine) as session: | |||
| dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| return [] | |||
| @@ -10,7 +10,6 @@ from typing import ( | |||
| Any, | |||
| Literal, | |||
| Optional, | |||
| TypedDict, | |||
| TypeVar, | |||
| Union, | |||
| ) | |||
| @@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): | |||
| raise NotImplementedError | |||
| class CharacterTextSplitter(TextSplitter): | |||
| """Splitting text that looks at characters.""" | |||
| def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: | |||
| """Create a new TextSplitter.""" | |||
| super().__init__(**kwargs) | |||
| self._separator = separator | |||
| def split_text(self, text: str) -> list[str]: | |||
| """Split incoming text and return chunks.""" | |||
| # First we naively split the large input into a bunch of smaller ones. | |||
| splits = _split_text_with_regex(text, self._separator, self._keep_separator) | |||
| _separator = "" if self._keep_separator else self._separator | |||
| _good_splits_lengths = [] # cache the lengths of the splits | |||
| if splits: | |||
| _good_splits_lengths.extend(self._length_function(splits)) | |||
| return self._merge_splits(splits, _separator, _good_splits_lengths) | |||
| class LineType(TypedDict): | |||
| """Line type as typed dict.""" | |||
| metadata: dict[str, str] | |||
| content: str | |||
| class HeaderType(TypedDict): | |||
| """Header type as typed dict.""" | |||
| level: int | |||
| name: str | |||
| data: str | |||
| class MarkdownHeaderTextSplitter: | |||
| """Splitting markdown files based on specified headers.""" | |||
| def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): | |||
| """Create a new MarkdownHeaderTextSplitter. | |||
| Args: | |||
| headers_to_split_on: Headers we want to track | |||
| return_each_line: Return each line w/ associated headers | |||
| """ | |||
| # Output line-by-line or aggregated into chunks w/ common headers | |||
| self.return_each_line = return_each_line | |||
| # Given the headers we want to split on, | |||
| # (e.g., "#, ##, etc") order by length | |||
| self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) | |||
| def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: | |||
| """Combine lines with common metadata into chunks | |||
| Args: | |||
| lines: Line of text / associated header metadata | |||
| """ | |||
| aggregated_chunks: list[LineType] = [] | |||
| for line in lines: | |||
| if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: | |||
| # If the last line in the aggregated list | |||
| # has the same metadata as the current line, | |||
| # append the current content to the last lines's content | |||
| aggregated_chunks[-1]["content"] += " \n" + line["content"] | |||
| else: | |||
| # Otherwise, append the current line to the aggregated list | |||
| aggregated_chunks.append(line) | |||
| return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] | |||
| def split_text(self, text: str) -> list[Document]: | |||
| """Split markdown file | |||
| Args: | |||
| text: Markdown file""" | |||
| # Split the input text by newline character ("\n"). | |||
| lines = text.split("\n") | |||
| # Final output | |||
| lines_with_metadata: list[LineType] = [] | |||
| # Content and metadata of the chunk currently being processed | |||
| current_content: list[str] = [] | |||
| current_metadata: dict[str, str] = {} | |||
| # Keep track of the nested header structure | |||
| # header_stack: List[Dict[str, Union[int, str]]] = [] | |||
| header_stack: list[HeaderType] = [] | |||
| initial_metadata: dict[str, str] = {} | |||
| for line in lines: | |||
| stripped_line = line.strip() | |||
| # Check each line against each of the header types (e.g., #, ##) | |||
| for sep, name in self.headers_to_split_on: | |||
| # Check if line starts with a header that we intend to split on | |||
| if stripped_line.startswith(sep) and ( | |||
| # Header with no text OR header is followed by space | |||
| # Both are valid conditions that sep is being used a header | |||
| len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " | |||
| ): | |||
| # Ensure we are tracking the header as metadata | |||
| if name is not None: | |||
| # Get the current header level | |||
| current_header_level = sep.count("#") | |||
| # Pop out headers of lower or same level from the stack | |||
| while header_stack and header_stack[-1]["level"] >= current_header_level: | |||
| # We have encountered a new header | |||
| # at the same or higher level | |||
| popped_header = header_stack.pop() | |||
| # Clear the metadata for the | |||
| # popped header in initial_metadata | |||
| if popped_header["name"] in initial_metadata: | |||
| initial_metadata.pop(popped_header["name"]) | |||
| # Push the current header to the stack | |||
| header: HeaderType = { | |||
| "level": current_header_level, | |||
| "name": name, | |||
| "data": stripped_line[len(sep) :].strip(), | |||
| } | |||
| header_stack.append(header) | |||
| # Update initial_metadata with the current header | |||
| initial_metadata[name] = header["data"] | |||
| # Add the previous line to the lines_with_metadata | |||
| # only if current_content is not empty | |||
| if current_content: | |||
| lines_with_metadata.append( | |||
| { | |||
| "content": "\n".join(current_content), | |||
| "metadata": current_metadata.copy(), | |||
| } | |||
| ) | |||
| current_content.clear() | |||
| break | |||
| else: | |||
| if stripped_line: | |||
| current_content.append(stripped_line) | |||
| elif current_content: | |||
| lines_with_metadata.append( | |||
| { | |||
| "content": "\n".join(current_content), | |||
| "metadata": current_metadata.copy(), | |||
| } | |||
| ) | |||
| current_content.clear() | |||
| current_metadata = initial_metadata.copy() | |||
| if current_content: | |||
| lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) | |||
| # lines_with_metadata has each line with associated header metadata | |||
| # aggregate these into chunks based on common metadata | |||
| if not self.return_each_line: | |||
| return self.aggregate_lines_to_chunks(lines_with_metadata) | |||
| else: | |||
| return [ | |||
| Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata | |||
| ] | |||
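| # Illustrative usage sketch for the splitter above (the header tuples and the sample markdown are made-up inputs, not values from this diff): | |||
| # | |||
| #     headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")] | |||
| #     splitter = MarkdownHeaderTextSplitter(headers_to_split_on) | |||
| #     docs = splitter.split_text("# Intro\nHello\n## Details\nWorld") | |||
| #     # docs[0]: page_content="Hello", metadata={"Header 1": "Intro"} | |||
| #     # docs[1]: page_content="World", metadata={"Header 1": "Intro", "Header 2": "Details"} | |||
| #     # With return_each_line=True, each non-header line becomes its own Document instead. | |||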
| # kw_only and slots are only available in newer Python versions (3.10+) | |||
| # @dataclass(frozen=True, kw_only=True, slots=True) | |||
| @dataclass(frozen=True) | |||
| class Tokenizer: | |||
| @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces | |||
| defined in the core.workflow.repository package. | |||
| """ | |||
| from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError | |||
| from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository | |||
| __all__ = [ | |||
| "DifyCoreRepositoryFactory", | |||
| "RepositoryImportError", | |||
| "SQLAlchemyWorkflowNodeExecutionRepository", | |||
| ] | |||
| @@ -0,0 +1,224 @@ | |||
| """ | |||
| Repository factory for dynamically creating repository instances based on configuration. | |||
| This module provides a Django-like settings system for repository implementations, | |||
| allowing users to configure different repository backends through string paths. | |||
| """ | |||
| import importlib | |||
| import inspect | |||
| import logging | |||
| from typing import Protocol, Union | |||
| from sqlalchemy.engine import Engine | |||
| from sqlalchemy.orm import sessionmaker | |||
| from configs import dify_config | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from models import Account, EndUser | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from models.workflow import WorkflowNodeExecutionTriggeredFrom | |||
| logger = logging.getLogger(__name__) | |||
| class RepositoryImportError(Exception): | |||
| """Raised when a repository implementation cannot be imported or instantiated.""" | |||
| pass | |||
| class DifyCoreRepositoryFactory: | |||
| """ | |||
| Factory for creating repository instances based on configuration. | |||
| This factory supports Django-like settings where repository implementations | |||
| are specified as module paths (e.g., 'module.submodule.ClassName'). | |||
| """ | |||
| @staticmethod | |||
| def _import_class(class_path: str) -> type: | |||
| """ | |||
| Import a class from a module path string. | |||
| Args: | |||
| class_path: Full module path to the class (e.g., 'module.submodule.ClassName') | |||
| Returns: | |||
| The imported class | |||
| Raises: | |||
| RepositoryImportError: If the class cannot be imported | |||
| """ | |||
| try: | |||
| module_path, class_name = class_path.rsplit(".", 1) | |||
| module = importlib.import_module(module_path) | |||
| repo_class = getattr(module, class_name) | |||
| assert isinstance(repo_class, type) | |||
| return repo_class | |||
| except (ValueError, ImportError, AttributeError) as e: | |||
| raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e | |||
| @staticmethod | |||
| def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore | |||
| """ | |||
| Validate that a class implements the expected repository interface. | |||
| Args: | |||
| repository_class: The class to validate | |||
| expected_interface: The expected interface/protocol | |||
| Raises: | |||
| RepositoryImportError: If the class doesn't implement the interface | |||
| """ | |||
| # Check if the class has all required methods from the protocol | |||
| required_methods = [ | |||
| method | |||
| for method in dir(expected_interface) | |||
| if not method.startswith("_") and callable(getattr(expected_interface, method, None)) | |||
| ] | |||
| missing_methods = [] | |||
| for method_name in required_methods: | |||
| if not hasattr(repository_class, method_name): | |||
| missing_methods.append(method_name) | |||
| if missing_methods: | |||
| raise RepositoryImportError( | |||
| f"Repository class '{repository_class.__name__}' does not implement required methods " | |||
| f"{missing_methods} from interface '{expected_interface.__name__}'" | |||
| ) | |||
| @staticmethod | |||
| def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: | |||
| """ | |||
| Validate that a repository class constructor accepts required parameters. | |||
| Args: | |||
| repository_class: The class to validate | |||
| required_params: List of required parameter names | |||
| Raises: | |||
| RepositoryImportError: If the constructor doesn't accept required parameters | |||
| """ | |||
| try: | |||
| # MyPy may flag the line below with the following error: | |||
| # | |||
| # > Accessing "__init__" on an instance is unsound, since | |||
| # > instance.__init__ could be from an incompatible subclass. | |||
| # | |||
| # Despite this, we need to ensure that the constructor of `repository_class` | |||
| # has a compatible signature. | |||
| signature = inspect.signature(repository_class.__init__) # type: ignore[misc] | |||
| param_names = list(signature.parameters.keys()) | |||
| # Remove 'self' parameter | |||
| if "self" in param_names: | |||
| param_names.remove("self") | |||
| missing_params = [param for param in required_params if param not in param_names] | |||
| if missing_params: | |||
| raise RepositoryImportError( | |||
| f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " | |||
| f"{missing_params}. Expected parameters: {required_params}" | |||
| ) | |||
| except Exception as e: | |||
| raise RepositoryImportError( | |||
| f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" | |||
| ) from e | |||
| @classmethod | |||
| def create_workflow_execution_repository( | |||
| cls, | |||
| session_factory: Union[sessionmaker, Engine], | |||
| user: Union[Account, EndUser], | |||
| app_id: str, | |||
| triggered_from: WorkflowRunTriggeredFrom, | |||
| ) -> WorkflowExecutionRepository: | |||
| """ | |||
| Create a WorkflowExecutionRepository instance based on configuration. | |||
| Args: | |||
| session_factory: SQLAlchemy sessionmaker or engine | |||
| user: Account or EndUser object | |||
| app_id: Application ID | |||
| triggered_from: Source of the execution trigger | |||
| Returns: | |||
| Configured WorkflowExecutionRepository instance | |||
| Raises: | |||
| RepositoryImportError: If the configured repository cannot be created | |||
| """ | |||
| class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY | |||
| logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") | |||
| try: | |||
| repository_class = cls._import_class(class_path) | |||
| cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) | |||
| cls._validate_constructor_signature( | |||
| repository_class, ["session_factory", "user", "app_id", "triggered_from"] | |||
| ) | |||
| return repository_class( # type: ignore[no-any-return] | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=app_id, | |||
| triggered_from=triggered_from, | |||
| ) | |||
| except RepositoryImportError: | |||
| # Re-raise our custom errors as-is | |||
| raise | |||
| except Exception as e: | |||
| logger.exception("Failed to create WorkflowExecutionRepository") | |||
| raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e | |||
| @classmethod | |||
| def create_workflow_node_execution_repository( | |||
| cls, | |||
| session_factory: Union[sessionmaker, Engine], | |||
| user: Union[Account, EndUser], | |||
| app_id: str, | |||
| triggered_from: WorkflowNodeExecutionTriggeredFrom, | |||
| ) -> WorkflowNodeExecutionRepository: | |||
| """ | |||
| Create a WorkflowNodeExecutionRepository instance based on configuration. | |||
| Args: | |||
| session_factory: SQLAlchemy sessionmaker or engine | |||
| user: Account or EndUser object | |||
| app_id: Application ID | |||
| triggered_from: Source of the execution trigger | |||
| Returns: | |||
| Configured WorkflowNodeExecutionRepository instance | |||
| Raises: | |||
| RepositoryImportError: If the configured repository cannot be created | |||
| """ | |||
| class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY | |||
| logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") | |||
| try: | |||
| repository_class = cls._import_class(class_path) | |||
| cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) | |||
| cls._validate_constructor_signature( | |||
| repository_class, ["session_factory", "user", "app_id", "triggered_from"] | |||
| ) | |||
| return repository_class( # type: ignore[no-any-return] | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=app_id, | |||
| triggered_from=triggered_from, | |||
| ) | |||
| except RepositoryImportError: | |||
| # Re-raise our custom errors as-is | |||
| raise | |||
| except Exception as e: | |||
| logger.exception("Failed to create WorkflowNodeExecutionRepository") | |||
| raise RepositoryImportError( | |||
| f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" | |||
| ) from e | |||
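| # Illustrative call sketch for the factory (the engine, user object, app id, and enum member below are placeholders, not values from this diff): | |||
| # | |||
| #     from sqlalchemy import create_engine | |||
| #     from sqlalchemy.orm import sessionmaker | |||
| # | |||
| #     engine = create_engine("postgresql://dify:dify@localhost/dify") | |||
| #     repo = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| #         session_factory=sessionmaker(bind=engine), | |||
| #         user=current_account,  # an Account or EndUser instance | |||
| #         app_id="app-123", | |||
| #         triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| #     ) | |||
| #     # The concrete class is resolved from dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY, | |||
| #     # e.g. the SQLAlchemyWorkflowNodeExecutionRepository exported from core.repositories above. | |||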
| @@ -16,6 +16,7 @@ from core.plugin.entities.parameters import ( | |||
| cast_parameter_value, | |||
| init_frontend_parameter, | |||
| ) | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY | |||
| @@ -179,6 +180,10 @@ class ToolInvokeMessage(BaseModel): | |||
| data: Mapping[str, Any] = Field(..., description="Detailed log data") | |||
| metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") | |||
| class RetrieverResourceMessage(BaseModel): | |||
| retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| class MessageType(Enum): | |||
| TEXT = "text" | |||
| IMAGE = "image" | |||
| @@ -191,13 +196,22 @@ class ToolInvokeMessage(BaseModel): | |||
| FILE = "file" | |||
| LOG = "log" | |||
| BLOB_CHUNK = "blob_chunk" | |||
| RETRIEVER_RESOURCES = "retriever_resources" | |||
| type: MessageType = MessageType.TEXT | |||
| """ | |||
| plain text, image url or link url | |||
| """ | |||
| message: ( | |||
| JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage | |||
| JsonMessage | |||
| | TextMessage | |||
| | BlobChunkMessage | |||
| | BlobMessage | |||
| | LogMessage | |||
| | FileMessage | |||
| | None | |||
| | VariableMessage | |||
| | RetrieverResourceMessage | |||
| ) | |||
| meta: dict[str, Any] | None = None | |||
| @@ -243,6 +257,7 @@ class ToolParameter(PluginParameter): | |||
| FILES = PluginParameterType.FILES.value | |||
| APP_SELECTOR = PluginParameterType.APP_SELECTOR.value | |||
| MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value | |||
| ANY = PluginParameterType.ANY.value | |||
| DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value | |||
| # MCP object and array type parameters | |||
| @@ -1,5 +1,4 @@ | |||
| import re | |||
| import uuid | |||
| from json import dumps as json_dumps | |||
| from json import loads as json_loads | |||
| from json.decoder import JSONDecodeError | |||
| @@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser: | |||
| # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ | |||
| path = re.sub(r"[^a-zA-Z0-9_-]", "", path) | |||
| if not path: | |||
| path = str(uuid.uuid4()) | |||
| path = "<root>" | |||
| interface["operation"]["operationId"] = f"{path}_{interface['method']}" | |||
| @@ -1,9 +1,9 @@ | |||
| import json | |||
| import sys | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from typing import Annotated, Any, TypeAlias | |||
| from pydantic import BaseModel, ConfigDict, field_validator | |||
| from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator | |||
| from core.file import File | |||
| @@ -11,6 +11,11 @@ from .types import SegmentType | |||
| class Segment(BaseModel): | |||
| """Segment is runtime type used during the execution of workflow. | |||
| Note: this class is abstract, you should use subclasses of this class instead. | |||
| """ | |||
| model_config = ConfigDict(frozen=True) | |||
| value_type: SegmentType | |||
| @@ -73,7 +78,7 @@ class StringSegment(Segment): | |||
| class FloatSegment(Segment): | |||
| value_type: SegmentType = SegmentType.NUMBER | |||
| value_type: SegmentType = SegmentType.FLOAT | |||
| value: float | |||
| # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. | |||
| # The following tests cannot pass. | |||
| @@ -92,7 +97,7 @@ class FloatSegment(Segment): | |||
| class IntegerSegment(Segment): | |||
| value_type: SegmentType = SegmentType.NUMBER | |||
| value_type: SegmentType = SegmentType.INTEGER | |||
| value: int | |||
| @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): | |||
| @property | |||
| def text(self) -> str: | |||
| return "" | |||
| def get_segment_discriminator(v: Any) -> SegmentType | None: | |||
| if isinstance(v, Segment): | |||
| return v.value_type | |||
| elif isinstance(v, dict): | |||
| value_type = v.get("value_type") | |||
| if value_type is None: | |||
| return None | |||
| try: | |||
| seg_type = SegmentType(value_type) | |||
| except ValueError: | |||
| return None | |||
| return seg_type | |||
| else: | |||
| # return None if the discriminator value isn't found | |||
| return None | |||
| # The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic. | |||
| # Use `Segment` for type hinting when serialization is not required. | |||
| # | |||
| # Note: | |||
| # - All variants in `SegmentUnion` must inherit from the `Segment` class. | |||
| # - The union must include all non-abstract subclasses of `Segment`, except: | |||
| # - `SegmentGroup`, which is not added to the variable pool. | |||
| # - `Variable` and its subclasses, which are handled by `VariableUnion`. | |||
| SegmentUnion: TypeAlias = Annotated[ | |||
| ( | |||
| Annotated[NoneSegment, Tag(SegmentType.NONE)] | |||
| | Annotated[StringSegment, Tag(SegmentType.STRING)] | |||
| | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] | |||
| | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] | |||
| | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] | |||
| | Annotated[FileSegment, Tag(SegmentType.FILE)] | |||
| | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] | |||
| | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] | |||
| | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] | |||
| | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] | |||
| | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] | |||
| ), | |||
| Discriminator(get_segment_discriminator), | |||
| ] | |||
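| # Minimal (de)serialization sketch for the discriminated union above, assuming Pydantic v2's TypeAdapter: | |||
| # | |||
| #     from pydantic import TypeAdapter | |||
| # | |||
| #     adapter = TypeAdapter(SegmentUnion) | |||
| #     seg = adapter.validate_python({"value_type": "string", "value": "hello"}) | |||
| #     assert isinstance(seg, StringSegment) | |||
| #     # get_segment_discriminator() reads "value_type" and routes the payload to the | |||
| #     # matching Tag; an unknown or missing value_type returns None and fails validation. | |||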
| @@ -1,8 +1,27 @@ | |||
| from collections.abc import Mapping | |||
| from enum import StrEnum | |||
| from typing import Any, Optional | |||
| from core.file.models import File | |||
| class ArrayValidation(StrEnum): | |||
| """Strategy for validating array elements""" | |||
| # Skip element validation (only check array container) | |||
| NONE = "none" | |||
| # Validate the first element (if array is non-empty) | |||
| FIRST = "first" | |||
| # Validate all elements in the array. | |||
| ALL = "all" | |||
| class SegmentType(StrEnum): | |||
| NUMBER = "number" | |||
| INTEGER = "integer" | |||
| FLOAT = "float" | |||
| STRING = "string" | |||
| OBJECT = "object" | |||
| SECRET = "secret" | |||
| @@ -19,16 +38,139 @@ class SegmentType(StrEnum): | |||
| GROUP = "group" | |||
| def is_array_type(self): | |||
| def is_array_type(self) -> bool: | |||
| return self in _ARRAY_TYPES | |||
| @classmethod | |||
| def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: | |||
| """ | |||
| Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. | |||
| Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. | |||
| For example, this may occur if the input is a generic Python object of type `object`. | |||
| """ | |||
| if isinstance(value, list): | |||
| elem_types: set[SegmentType] = set() | |||
| for i in value: | |||
| segment_type = cls.infer_segment_type(i) | |||
| if segment_type is None: | |||
| return None | |||
| elem_types.add(segment_type) | |||
| if len(elem_types) != 1: | |||
| if elem_types.issubset(_NUMERICAL_TYPES): | |||
| return SegmentType.ARRAY_NUMBER | |||
| return SegmentType.ARRAY_ANY | |||
| elif all(i.is_array_type() for i in elem_types): | |||
| return SegmentType.ARRAY_ANY | |||
| match elem_types.pop(): | |||
| case SegmentType.STRING: | |||
| return SegmentType.ARRAY_STRING | |||
| case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: | |||
| return SegmentType.ARRAY_NUMBER | |||
| case SegmentType.OBJECT: | |||
| return SegmentType.ARRAY_OBJECT | |||
| case SegmentType.FILE: | |||
| return SegmentType.ARRAY_FILE | |||
| case SegmentType.NONE: | |||
| return SegmentType.ARRAY_ANY | |||
| case _: | |||
| # This should be unreachable. | |||
| raise ValueError(f"unsupported value {value}") | |||
| if value is None: | |||
| return SegmentType.NONE | |||
| elif isinstance(value, int) and not isinstance(value, bool): | |||
| return SegmentType.INTEGER | |||
| elif isinstance(value, float): | |||
| return SegmentType.FLOAT | |||
| elif isinstance(value, str): | |||
| return SegmentType.STRING | |||
| elif isinstance(value, dict): | |||
| return SegmentType.OBJECT | |||
| elif isinstance(value, File): | |||
| return SegmentType.FILE | |||
| else: | |||
| return None | |||
| def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: | |||
| if not isinstance(value, list): | |||
| return False | |||
| # Skip element validation if array is empty | |||
| if len(value) == 0: | |||
| return True | |||
| if self == SegmentType.ARRAY_ANY: | |||
| return True | |||
| element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] | |||
| if array_validation == ArrayValidation.NONE: | |||
| return True | |||
| elif array_validation == ArrayValidation.FIRST: | |||
| return element_type.is_valid(value[0]) | |||
| else: | |||
| return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) | |||
| def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: | |||
| """ | |||
| Check if a value matches the segment type. | |||
| Users of `SegmentType` should call this method, instead of using | |||
| `isinstance` manually. | |||
| Args: | |||
| value: The value to validate | |||
| array_validation: Validation strategy for array types (ignored for non-array types) | |||
| Returns: | |||
| True if the value matches the type under the given validation strategy | |||
| """ | |||
| if self.is_array_type(): | |||
| return self._validate_array(value, array_validation) | |||
| elif self == SegmentType.NUMBER: | |||
| return isinstance(value, (int, float)) | |||
| elif self == SegmentType.STRING: | |||
| return isinstance(value, str) | |||
| elif self == SegmentType.OBJECT: | |||
| return isinstance(value, dict) | |||
| elif self == SegmentType.SECRET: | |||
| return isinstance(value, str) | |||
| elif self == SegmentType.FILE: | |||
| return isinstance(value, File) | |||
| elif self == SegmentType.NONE: | |||
| return value is None | |||
| else: | |||
| raise AssertionError("this statement should be unreachable.") | |||
| def exposed_type(self) -> "SegmentType": | |||
| """Returns the type exposed to the frontend. | |||
| The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. | |||
| """ | |||
| if self in (SegmentType.INTEGER, SegmentType.FLOAT): | |||
| return SegmentType.NUMBER | |||
| return self | |||
| _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { | |||
| # ARRAY_ANY does not have a corresponding element type. | |||
| SegmentType.ARRAY_STRING: SegmentType.STRING, | |||
| SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, | |||
| SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, | |||
| SegmentType.ARRAY_FILE: SegmentType.FILE, | |||
| } | |||
| _ARRAY_TYPES = frozenset( | |||
| list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) | |||
| + [ | |||
| SegmentType.ARRAY_ANY, | |||
| SegmentType.ARRAY_STRING, | |||
| SegmentType.ARRAY_NUMBER, | |||
| SegmentType.ARRAY_OBJECT, | |||
| SegmentType.ARRAY_FILE, | |||
| ] | |||
| ) | |||
| _NUMERICAL_TYPES = frozenset( | |||
| [ | |||
| SegmentType.NUMBER, | |||
| SegmentType.INTEGER, | |||
| SegmentType.FLOAT, | |||
| ] | |||
| ) | |||
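| # Expected behaviour of the helpers above, for a few representative inputs: | |||
| # | |||
| #     SegmentType.infer_segment_type("abc")        # SegmentType.STRING | |||
| #     SegmentType.infer_segment_type(3)            # SegmentType.INTEGER (bool is excluded) | |||
| #     SegmentType.infer_segment_type([1, 2.5])     # SegmentType.ARRAY_NUMBER (mixed int/float) | |||
| #     SegmentType.infer_segment_type([1, "a"])     # SegmentType.ARRAY_ANY | |||
| #     SegmentType.ARRAY_STRING.is_valid(["a", 1])  # True with the default FIRST strategy, | |||
| #                                                  # False with ArrayValidation.ALL | |||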
| @@ -3,6 +3,10 @@ from typing import Any, cast | |||
| from uuid import uuid4 | |||
| from pydantic import BaseModel, Field | |||
| from typing import Annotated, TypeAlias, cast | |||
| from uuid import uuid4 | |||
| from pydantic import Discriminator, Field, Tag | |||
| from core.helper import encrypter | |||
| @@ -20,6 +24,7 @@ from .segments import ( | |||
| ObjectSegment, | |||
| Segment, | |||
| StringSegment, | |||
| get_segment_discriminator, | |||
| ) | |||
| from .types import SegmentType | |||
| @@ -27,6 +32,10 @@ from .types import SegmentType | |||
| class Variable(Segment): | |||
| """ | |||
| A variable is a segment that has a name. | |||
| It is mainly used to store segments and their selector in VariablePool. | |||
| Note: this class is abstract, you should use subclasses of this class instead. | |||
| """ | |||
| id: str = Field( | |||
| @@ -122,3 +131,26 @@ class RAGPipelineVariable(BaseModel): | |||
| class RAGPipelineVariableInput(BaseModel): | |||
| variable: RAGPipelineVariable | |||
| value: Any | |||
| # The `VariableUnion` type is used to enable serialization and deserialization with Pydantic. | |||
| # Use `Variable` for type hinting when serialization is not required. | |||
| # | |||
| # Note: | |||
| # - All variants in `VariableUnion` must inherit from the `Variable` class. | |||
| # - The union must include all non-abstract subclasses of `Variable`. | |||
| VariableUnion: TypeAlias = Annotated[ | |||
| ( | |||
| Annotated[NoneVariable, Tag(SegmentType.NONE)] | |||
| | Annotated[StringVariable, Tag(SegmentType.STRING)] | |||
| | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] | |||
| | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] | |||
| | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] | |||
| | Annotated[FileVariable, Tag(SegmentType.FILE)] | |||
| | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] | |||
| | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] | |||
| | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] | |||
| | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] | |||
| | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] | |||
| | Annotated[SecretVariable, Tag(SegmentType.SECRET)] | |||
| ), | |||
| Discriminator(get_segment_discriminator), | |||
| ] | |||
| @@ -1,7 +1,7 @@ | |||
| import re | |||
| from collections import defaultdict | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Union | |||
| from typing import Annotated, Any, Union, cast | |||
| from pydantic import BaseModel, Field | |||
| @@ -17,6 +17,9 @@ from core.workflow.constants import ( | |||
| SYSTEM_VARIABLE_NODE_ID, | |||
| ) | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.variables.variables import VariableUnion | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.system_variable import SystemVariable | |||
| from factories import variable_factory | |||
| VariableValue = Union[str, int, float, dict, list, File] | |||
| @@ -29,24 +32,24 @@ class VariablePool(BaseModel): | |||
| # The first element of the selector is the node id, it's the first-level key in the dictionary. | |||
| # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the | |||
| # elements of the selector except the first one. | |||
| variable_dictionary: dict[str, dict[int, Segment]] = Field( | |||
| variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( | |||
| description="Variables mapping", | |||
| default=defaultdict(dict), | |||
| ) | |||
| # TODO: This user inputs is not used for pool. | |||
| # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. | |||
| user_inputs: Mapping[str, Any] = Field( | |||
| description="User inputs", | |||
| default_factory=dict, | |||
| ) | |||
| system_variables: Mapping[SystemVariableKey, Any] = Field( | |||
| system_variables: SystemVariable = Field( | |||
| description="System variables", | |||
| default_factory=dict, | |||
| ) | |||
| environment_variables: Sequence[Variable] = Field( | |||
| environment_variables: Sequence[VariableUnion] = Field( | |||
| description="Environment variables.", | |||
| default_factory=list, | |||
| ) | |||
| conversation_variables: Sequence[Variable] = Field( | |||
| conversation_variables: Sequence[VariableUnion] = Field( | |||
| description="Conversation variables.", | |||
| default_factory=list, | |||
| ) | |||
| @@ -56,8 +59,8 @@ class VariablePool(BaseModel): | |||
| ) | |||
| def model_post_init(self, context: Any, /) -> None: | |||
| for key, value in self.system_variables.items(): | |||
| self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) | |||
| # Create a mapping from field names to SystemVariableKey enum values | |||
| self._add_system_variables(self.system_variables) | |||
| # Add environment variables to the variable pool | |||
| for var in self.environment_variables: | |||
| self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) | |||
| @@ -96,8 +99,22 @@ class VariablePool(BaseModel): | |||
| segment = variable_factory.build_segment(value) | |||
| variable = variable_factory.segment_to_variable(segment=segment, selector=selector) | |||
| hash_key = hash(tuple(selector[1:])) | |||
| self.variable_dictionary[selector[0]][hash_key] = variable | |||
| key, hash_key = self._selector_to_keys(selector) | |||
| # Based on the definition of `VariableUnion`, | |||
| # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | |||
| self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) | |||
| @classmethod | |||
| def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: | |||
| return selector[0], hash(tuple(selector[1:])) | |||
| def _has(self, selector: Sequence[str]) -> bool: | |||
| key, hash_key = self._selector_to_keys(selector) | |||
| if key not in self.variable_dictionary: | |||
| return False | |||
| if hash_key not in self.variable_dictionary[key]: | |||
| return False | |||
| return True | |||
| def get(self, selector: Sequence[str], /) -> Segment | None: | |||
| """ | |||
| @@ -115,8 +132,8 @@ class VariablePool(BaseModel): | |||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||
| return None | |||
| hash_key = hash(tuple(selector[1:])) | |||
| value = self.variable_dictionary[selector[0]].get(hash_key) | |||
| key, hash_key = self._selector_to_keys(selector) | |||
| value: Segment | None = self.variable_dictionary[key].get(hash_key) | |||
| if value is None: | |||
| selector, attr = selector[:-1], selector[-1] | |||
| @@ -149,8 +166,8 @@ class VariablePool(BaseModel): | |||
| if len(selector) == 1: | |||
| self.variable_dictionary[selector[0]] = {} | |||
| return | |||
| hash_key = hash(tuple(selector[1:])) | |||
| self.variable_dictionary[selector[0]].pop(hash_key, None) | |||
| key, hash_key = self._selector_to_keys(selector) | |||
| self.variable_dictionary[key].pop(hash_key, None) | |||
| def convert_template(self, template: str, /): | |||
| parts = VARIABLE_PATTERN.split(template) | |||
| @@ -167,3 +184,20 @@ class VariablePool(BaseModel): | |||
| if isinstance(segment, FileSegment): | |||
| return segment | |||
| return None | |||
| def _add_system_variables(self, system_variable: SystemVariable): | |||
| sys_var_mapping = system_variable.to_dict() | |||
| for key, value in sys_var_mapping.items(): | |||
| if value is None: | |||
| continue | |||
| selector = (SYSTEM_VARIABLE_NODE_ID, key) | |||
| # If the system variable already exists, do not add it again. | |||
| # This ensures that we can keep the id of the system variables intact. | |||
| if self._has(selector): | |||
| continue | |||
| self.add(selector, value) # type: ignore | |||
| @classmethod | |||
| def empty(cls) -> "VariablePool": | |||
| """Create an empty variable pool.""" | |||
| return cls(system_variables=SystemVariable.empty()) | |||
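| # Selector sketch for the pool above (the node id and variable name are hypothetical): | |||
| # | |||
| #     pool = VariablePool.empty() | |||
| #     pool.add(("node_a", "answer"), "hello")   # key "node_a", hash key = hash(("answer",)) | |||
| #     segment = pool.get(("node_a", "answer"))  # a string variable/segment wrapping "hello" | |||
| #     pool.remove(("node_a", "answer")) | |||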
| @@ -1,79 +0,0 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode | |||
| from models.enums import UserFrom | |||
| from models.workflow import Workflow, WorkflowType | |||
| from .node_entities import NodeRunResult | |||
| from .variable_pool import VariablePool | |||
| class WorkflowNodeAndResult: | |||
| node: BaseNode | |||
| result: Optional[NodeRunResult] = None | |||
| def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): | |||
| self.node = node | |||
| self.result = result | |||
| class WorkflowRunState: | |||
| tenant_id: str | |||
| app_id: str | |||
| workflow_id: str | |||
| workflow_type: WorkflowType | |||
| user_id: str | |||
| user_from: UserFrom | |||
| invoke_from: InvokeFrom | |||
| workflow_call_depth: int | |||
| start_at: float | |||
| variable_pool: VariablePool | |||
| total_tokens: int = 0 | |||
| workflow_nodes_and_results: list[WorkflowNodeAndResult] | |||
| class NodeRun(BaseModel): | |||
| node_id: str | |||
| iteration_node_id: str | |||
| loop_node_id: str | |||
| workflow_node_runs: list[NodeRun] | |||
| workflow_node_steps: int | |||
| current_iteration_state: Optional[BaseIterationState] | |||
| current_loop_state: Optional[BaseLoopState] | |||
| def __init__( | |||
| self, | |||
| workflow: Workflow, | |||
| start_at: float, | |||
| variable_pool: VariablePool, | |||
| user_id: str, | |||
| user_from: UserFrom, | |||
| invoke_from: InvokeFrom, | |||
| workflow_call_depth: int, | |||
| ): | |||
| self.workflow_id = workflow.id | |||
| self.tenant_id = workflow.tenant_id | |||
| self.app_id = workflow.app_id | |||
| self.workflow_type = WorkflowType.value_of(workflow.type) | |||
| self.user_id = user_id | |||
| self.user_from = user_from | |||
| self.invoke_from = invoke_from | |||
| self.workflow_call_depth = workflow_call_depth | |||
| self.start_at = start_at | |||
| self.variable_pool = variable_pool | |||
| self.total_tokens = 0 | |||
| self.workflow_node_steps = 1 | |||
| self.workflow_node_runs = [] | |||
| self.current_iteration_state = None | |||
| self.current_loop_state = None | |||
| @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): | |||
| """total tokens""" | |||
| llm_usage: LLMUsage = LLMUsage.empty_usage() | |||
| """llm usage info""" | |||
| # The `outputs` field stores the final output values generated by executing workflows or chatflows. | |||
| # | |||
| # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent | |||
| # after a serialization and deserialization round trip. | |||
| outputs: dict[str, Any] = {} | |||
| """outputs""" | |||
| node_run_steps: int = 0 | |||
| """node run steps""" | |||
| @@ -1,4 +1,5 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from decimal import Decimal | |||
| from typing import Any, Optional | |||
| from configs import dify_config | |||
| @@ -114,8 +115,10 @@ class CodeNode(BaseNode[CodeNodeData]): | |||
| ) | |||
| if isinstance(value, float): | |||
| decimal_value = Decimal(str(value)).normalize() | |||
| precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] | |||
| # raise error if precision is too high | |||
| if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: | |||
| if precision > dify_config.CODE_MAX_PRECISION: | |||
| raise OutputValidationError( | |||
| f"Output variable `{variable}` has too high precision," | |||
| f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." | |||
| @@ -521,18 +521,52 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| ) | |||
| return | |||
| elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": None}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| yield NodeInIterationFailedEvent( | |||
| **metadata_event.model_dump(), | |||
| ) | |||
| outputs[current_index] = None | |||
| # clean nodes resources | |||
| for node_id in iteration_graph.node_ids: | |||
| variable_pool.remove([node_id]) | |||
| # iteration run failed | |||
| if self.node_data.is_parallel: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": outputs}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| else: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": outputs}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| # stop the iterator | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| return | |||
| yield metadata_event | |||
| current_output_segment = variable_pool.get(self.node_data.output_selector) | |||
| @@ -144,6 +144,8 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| error=str(e), | |||
| error_type=type(e).__name__, | |||
| ) | |||
| finally: | |||
| db.session.close() | |||
| def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: | |||
| available_datasets = [] | |||
| @@ -171,6 +173,9 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| .all() | |||
| ) | |||
| # avoid blocking at retrieval | |||
| db.session.close() | |||
| for dataset in results: | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| @@ -1,11 +1,29 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Literal, Optional | |||
| from typing import Annotated, Any, Literal, Optional | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import AfterValidator, BaseModel, Field | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData | |||
| from core.workflow.utils.condition.entities import Condition | |||
| _VALID_VAR_TYPE = frozenset( | |||
| [ | |||
| SegmentType.STRING, | |||
| SegmentType.NUMBER, | |||
| SegmentType.OBJECT, | |||
| SegmentType.ARRAY_STRING, | |||
| SegmentType.ARRAY_NUMBER, | |||
| SegmentType.ARRAY_OBJECT, | |||
| ] | |||
| ) | |||
| def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: | |||
| if seg_type not in _VALID_VAR_TYPE: | |||
| raise ValueError(f"Invalid loop variable type: {seg_type}") | |||
| return seg_type | |||
| class LoopVariableData(BaseModel): | |||
| """ | |||
| @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): | |||
| """ | |||
| label: str | |||
| var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] | |||
| var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] | |||
| value_type: Literal["variable", "constant"] | |||
| value: Optional[Any | list[str]] = None | |||
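| # Validation sketch for the annotated var_type above (illustrative payloads): | |||
| # | |||
| #     LoopVariableData(label="n", var_type=SegmentType.NUMBER, value_type="constant", value=1)   # ok | |||
| #     LoopVariableData(label="f", var_type=SegmentType.FILE, value_type="constant", value=None)  # ValidationError | |||
| #     # Raw strings such as "array[string]" are also accepted, since SegmentType is a StrEnum. | |||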
| @@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast | |||
| from configs import dify_config | |||
| from core.variables import ( | |||
| ArrayNumberSegment, | |||
| ArrayObjectSegment, | |||
| ArrayStringSegment, | |||
| IntegerSegment, | |||
| ObjectSegment, | |||
| Segment, | |||
| SegmentType, | |||
| StringSegment, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| @@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from core.workflow.nodes.loop.entities import LoopNodeData | |||
| from core.workflow.utils.condition.processor import ConditionProcessor | |||
| from factories.variable_factory import TypeMismatchError, build_segment_with_type | |||
| if TYPE_CHECKING: | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]): | |||
| return variable_mapping | |||
| @staticmethod | |||
| def _get_segment_for_constant(var_type: str, value: Any) -> Segment: | |||
| def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: | |||
| """Get the appropriate segment type for a constant value.""" | |||
| segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { | |||
| "string": (StringSegment, SegmentType.STRING), | |||
| "number": (IntegerSegment, SegmentType.NUMBER), | |||
| "object": (ObjectSegment, SegmentType.OBJECT), | |||
| "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), | |||
| "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER), | |||
| "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), | |||
| } | |||
| if var_type in ["array[string]", "array[number]", "array[object]"]: | |||
| if value: | |||
| if value and isinstance(value, str): | |||
| value = json.loads(value) | |||
| else: | |||
| value = [] | |||
| segment_info = segment_mapping.get(var_type) | |||
| if not segment_info: | |||
| raise ValueError(f"Invalid variable type: {var_type}") | |||
| segment_class, value_type = segment_info | |||
| return segment_class(value=value, value_type=value_type) | |||
| try: | |||
| return build_segment_with_type(var_type, value) | |||
| except TypeMismatchError as type_exc: | |||
| # Attempt to parse the value as a JSON-encoded string, if applicable. | |||
| if not isinstance(value, str): | |||
| raise | |||
| try: | |||
| value = json.loads(value) | |||
| except ValueError: | |||
| raise type_exc | |||
| return build_segment_with_type(var_type, value) | |||
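| # Behaviour sketch for the constant handling above (assumes build_segment_with_type | |||
| # rejects a raw string for non-string types, as the TypeMismatchError handling implies): | |||
| # | |||
| #     LoopNode._get_segment_for_constant(SegmentType.ARRAY_NUMBER, "[1, 2, 3]") | |||
| #     # -> the JSON string is parsed up front for array types, yielding an ArrayNumberSegment | |||
| #     LoopNode._get_segment_for_constant(SegmentType.OBJECT, '{"a": 1}') | |||
| #     # -> the first build attempt raises TypeMismatchError, the string is json.loads-ed, | |||
| #     #    and the retry yields an ObjectSegment | |||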
| @@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]): | |||
| def _run(self) -> NodeRunResult: | |||
| node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) | |||
| system_inputs = self.graph_runtime_state.variable_pool.system_variables | |||
| system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() | |||
| # TODO: System variables should be directly accessible, no need for special handling | |||
| # Set system variables as node outputs. | |||
| @@ -22,7 +22,7 @@ from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import AgentLogEvent | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| @@ -373,6 +373,12 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| agent_logs.append(agent_log) | |||
| yield agent_log | |||
| elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: | |||
| assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage) | |||
| yield RunRetrieverResourceEvent( | |||
| retriever_resources=message.message.retriever_resources, | |||
| context=message.message.context, | |||
| ) | |||
| # Add agent_logs to outputs['json'] to ensure frontend can access thinking process | |||
| json_output: list[dict[str, Any]] = [] | |||
| @@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): | |||
| def get_zero_value(t: SegmentType): | |||
| # TODO(QuantumGhost): this should be a method of `SegmentType`. | |||
| match t: | |||
| case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: | |||
| return variable_factory.build_segment([]) | |||
| @@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType): | |||
| return variable_factory.build_segment({}) | |||
| case SegmentType.STRING: | |||
| return variable_factory.build_segment("") | |||
| case SegmentType.INTEGER: | |||
| return variable_factory.build_segment(0) | |||
| case SegmentType.FLOAT: | |||
| return variable_factory.build_segment(0.0) | |||
| case SegmentType.NUMBER: | |||
| return variable_factory.build_segment(0) | |||
| case _: | |||