
Introduce Plugins (#13836)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: xhe <xw897002528@gmail.com>
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: takatost <takatost@gmail.com>
Co-authored-by: kurokobo <kuro664@gmail.com>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: AkaraChen <akarachen@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com>
Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com>
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Novice <857526207@qq.com>
Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com>
Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com>
Co-authored-by: eux <euxuuu@gmail.com>
Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com>
Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com>
Co-authored-by: lotsik <lotsik@mail.ru>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com>
Co-authored-by: CN-P5 <heibai2006@gmail.com>
Co-authored-by: CN-P5 <heibai2006@qq.com>
Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Boris Feld <lothiraldan@gmail.com>
Co-authored-by: mbo <himabo@gmail.com>
Co-authored-by: mabo <mabo@aeyes.ai>
Co-authored-by: Warren Chen <warren.chen830@gmail.com>
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
Co-authored-by: jiandanfeng <chenjh3@wangsu.com>
Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com>
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com>
Co-authored-by: Xu Song <xusong.vip@gmail.com>
Co-authored-by: rayshaw001 <396301947@163.com>
Co-authored-by: Ding Jiatong <dingjiatong@gmail.com>
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
Co-authored-by: JasonVV <jasonwangiii@outlook.com>
Co-authored-by: le0zh <newlight@qq.com>
Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com>
Co-authored-by: k-zaku <zaku99@outlook.jp>
Co-authored-by: luckylhb90 <luckylhb90@gmail.com>
Co-authored-by: hobo.l <hobo.l@binance.com>
Co-authored-by: jiangbo721 <365065261@qq.com>
Co-authored-by: 刘江波 <jiangbo721@163.com>
Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com>
Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: sino <sino2322@gmail.com>
Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com>
Co-authored-by: lowell <lowell.hu@zkteco.in>
Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com>
Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com>
Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com>
Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com>
Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com>
Co-authored-by: Jason <ggbbddjm@gmail.com>
Co-authored-by: Xin Zhang <sjhpzx@gmail.com>
Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com>
Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com>
Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com>
Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com>
Co-authored-by: Yingchun Lai <laiyingchun@apache.org>
Co-authored-by: Hash Brown <hi@xzd.me>
Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com>
Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com>
Co-authored-by: aplio <ryo.091219@gmail.com>
Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com>
Co-authored-by: Nam Vu <zuzoovn@gmail.com>
Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com>
Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com>
Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com>
Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com>
Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp>
Co-authored-by: HQidea <HQidea@users.noreply.github.com>
Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com>
Co-authored-by: xhe <xw897002528@gmail.com>
Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com>
Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com>
Co-authored-by: engchina <12236799+engchina@users.noreply.github.com>
Co-authored-by: engchina <atjapan2015@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: 呆萌闷油瓶 <253605712@qq.com>
Co-authored-by: Kemal <kemalmeler@outlook.com>
Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com>
Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com>
Co-authored-by: steven <sunzwj@digitalchina.com>
Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com>
Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com>
Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com>
Co-authored-by: 胡春东 <gycm520@gmail.com>
Co-authored-by: Junjie.M <118170653@qq.com>
Co-authored-by: MuYu <mr.muzea@gmail.com>
Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com>
Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com>
Co-authored-by: Fei He <droxer.he@gmail.com>
Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com>
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
Co-authored-by: liuzhenghua <1090179900@qq.com>
Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com>
Co-authored-by: AugNSo <song.tiankai@icloud.com>
Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com>
Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com>
Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com>
Co-authored-by: Hundredwz <1808096180@qq.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
tag: 1.0.0
Author: Yeuoly, 8 months ago
Commit: 403e2d58b9
100 changed files with 3015 additions and 762 deletions
  1. +3 −2     .devcontainer/post_create_command.sh
  2. +0 −6     .github/workflows/api-tests.yml
  3. +1 −0     .github/workflows/db-migration-test.yml
  4. +8 −2     .github/workflows/style.yml
  5. +3 −3     .github/workflows/tool-test-sdks.yaml
  6. +2 −2     .github/workflows/translate-i18n-base-on-english.yml
  7. +3 −3     .github/workflows/web-tests.yml
  8. +7 −0     .gitignore
  9. +5 −0     api/.dockerignore
  10. +16 −1   api/.env.example
  11. +4 −0    api/Dockerfile
  12. +70 −3   api/commands.py
  13. +60 −0   api/configs/feature/__init__.py
  14. +1 −1    api/configs/packaging/__init__.py
  15. +10 −0   api/contexts/__init__.py
  16. +14 −2   api/controllers/console/__init__.py
  17. +23 −7   api/controllers/console/admin.py
  18. +12 −1   api/controllers/console/apikey.py
  19. +20 −1   api/controllers/console/app/app_import.py
  20. +32 −27  api/controllers/console/app/site.py
  21. +39 −5   api/controllers/console/app/workflow.py
  22. +7 −3    api/controllers/console/auth/forgot_password.py
  23. +4 −1    api/controllers/console/auth/oauth.py
  24. +63 −49  api/controllers/console/datasets/data_source.py
  25. +3 −0    api/controllers/console/datasets/datasets.py
  26. +19 −3   api/controllers/console/datasets/datasets_document.py
  27. +8 −1    api/controllers/console/init_validate.py
  28. +4 −3    api/controllers/console/setup.py
  29. +56 −0   api/controllers/console/workspace/__init__.py
  30. +36 −0   api/controllers/console/workspace/agent_providers.py
  31. +205 −0  api/controllers/console/workspace/endpoint.py
  32. +2 −2    api/controllers/console/workspace/load_balancing_config.py
  33. +10 −45  api/controllers/console/workspace/model_providers.py
  34. +7 −7    api/controllers/console/workspace/models.py
  35. +475 −0  api/controllers/console/workspace/plugin.py
  36. +108 −52 api/controllers/console/workspace/tool_providers.py
  37. +7 −2    api/controllers/console/wraps.py
  38. +1 −1    api/controllers/files/__init__.py
  39. +69 −0   api/controllers/files/upload.py
  40. +1 −0    api/controllers/inner_api/__init__.py
  41. +0 −0    api/controllers/inner_api/plugin/__init__.py
  42. +293 −0  api/controllers/inner_api/plugin/plugin.py
  43. +116 −0  api/controllers/inner_api/plugin/wraps.py
  44. +3 −3    api/controllers/inner_api/workspace/workspace.py
  45. +19 −3   api/controllers/inner_api/wraps.py
  46. +44 −70  api/core/agent/base_agent_runner.py
  47. +31 −43  api/core/agent/cot_agent_runner.py
  48. +4 −2    api/core/agent/cot_chat_agent_runner.py
  49. +16 −5   api/core/agent/entities.py
  50. +15 −24  api/core/agent/fc_agent_runner.py
  51. +89 −0   api/core/agent/plugin_entities.py
  52. +42 −0   api/core/agent/strategy/base.py
  53. +59 −0   api/core/agent/strategy/plugin.py
  54. +7 −6    api/core/app/app_config/easy_ui_based_app/model_config/converter.py
  55. +12 −2   api/core/app/app_config/easy_ui_based_app/model_config/manager.py
  56. +23 −13  api/core/app/apps/advanced_chat/app_generator.py
  57. +2 −2    api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
  58. +1 −1    api/core/app/apps/advanced_chat/app_runner.py
  59. +4 −5    api/core/app/apps/advanced_chat/generate_response_converter.py
  60. +9 −2    api/core/app/apps/advanced_chat/generate_task_pipeline.py
  61. +12 −6   api/core/app/apps/agent_chat/app_generator.py
  62. +16 −99  api/core/app/apps/agent_chat/app_runner.py
  63. +9 −11   api/core/app/apps/agent_chat/generate_response_converter.py
  64. +8 −18   api/core/app/apps/base_app_generate_response_converter.py
  65. +21 −2   api/core/app/apps/base_app_generator.py
  66. +2 −2    api/core/app/apps/base_app_queue_manager.py
  67. +3 −3    api/core/app/apps/chat/app_generator.py
  68. +7 −9    api/core/app/apps/chat/generate_response_converter.py
  69. +5 −5    api/core/app/apps/completion/app_generator.py
  70. +7 −9    api/core/app/apps/completion/generate_response_converter.py
  71. +33 −17  api/core/app/apps/workflow/app_generator.py
  72. +7 −9    api/core/app/apps/workflow/generate_response_converter.py
  73. +8 −1    api/core/app/apps/workflow/generate_task_pipeline.py
  74. +16 −0   api/core/app/apps/workflow_app_runner.py
  75. +1 −1    api/core/app/entities/app_invoke_entities.py
  76. +19 −1   api/core/app/entities/queue_entities.py
  77. +26 −0   api/core/app/entities/task_entities.py
  78. +3 −1    api/core/app/features/hosting_moderation/hosting_moderation.py
  79. +3 −1    api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  80. +26 −3   api/core/app/task_pipeline/workflow_cycle_manage.py
  81. +2 −2    api/core/callback_handler/agent_tool_callback_handler.py
  82. +22 −1   api/core/callback_handler/workflow_tool_callback_handler.py
  83. +1 −0    api/core/entities/__init__.py
  84. +42 −0   api/core/entities/parameter_entities.py
  85. +120 −104 api/core/entities/provider_configuration.py
  86. +79 −3   api/core/entities/provider_entities.py
  87. +35 −0   api/core/file/helpers.py
  88. +12 −1   api/core/file/models.py
  89. +69 −0   api/core/file/upload_file_parser.py
  90. +17 −0   api/core/helper/download.py
  91. +35 −0   api/core/helper/marketplace.py
  92. +22 −10  api/core/helper/moderation.py
  93. +0 −1    api/core/helper/ssrf_proxy.py
  94. +1 −0    api/core/helper/tool_provider_cache.py
  95. +52 −0   api/core/helper/url_signer.py
  96. +16 −9   api/core/hosting_configuration.py
  97. +3 −5    api/core/indexing_runner.py
  98. +8 −8    api/core/llm_generator/llm_generator.py
  99. +40 −4   api/core/model_manager.py
  100. +0 −0   api/core/model_runtime/entities/model_entities.py

+3 −2   .devcontainer/post_create_command.sh

 #!/bin/bash

-cd web && npm install
+npm add -g pnpm@9.12.2
+cd web && pnpm install
 pipx install poetry

 echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
 echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
-echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
+echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
 echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc



+0 −6   .github/workflows/api-tests.yml

 - name: Run Unit tests
   run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh

-- name: Run ModelRuntime
-  run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
-
 - name: Run dify config tests
   run: poetry run -P api python dev/pytest/pytest_config_tests.py

-- name: Run Tool
-  run: poetry run -P api bash dev/pytest/pytest_tools.sh
-
 - name: Run mypy
   run: |
     poetry run -C api python -m mypy --install-types --non-interactive .

+1 −0   .github/workflows/db-migration-test.yml

 pull_request:
   branches:
     - main
+    - plugins/beta
   paths:
     - api/migrations/**
     - .github/workflows/db-migration-test.yml

+8 −2   .github/workflows/style.yml

   with:
     files: web/**

+- name: Install pnpm
+  uses: pnpm/action-setup@v4
+  with:
+    version: 10
+    run_install: false
+
 - name: Setup NodeJS
   uses: actions/setup-node@v4
   if: steps.changed-files.outputs.any_changed == 'true'
   with:
     node-version: 20
-    cache: yarn
+    cache: pnpm
     cache-dependency-path: ./web/package.json

 - name: Web dependencies
   if: steps.changed-files.outputs.any_changed == 'true'
-  run: yarn install --frozen-lockfile
+  run: pnpm install --frozen-lockfile

 - name: Web style check
   if: steps.changed-files.outputs.any_changed == 'true'

+3 −3   .github/workflows/tool-test-sdks.yaml

   with:
     node-version: ${{ matrix.node-version }}
     cache: ''
-    cache-dependency-path: 'yarn.lock'
+    cache-dependency-path: 'pnpm-lock.yaml'

 - name: Install Dependencies
-  run: yarn install
+  run: pnpm install

 - name: Test
-  run: yarn test
+  run: pnpm test

+2 −2   .github/workflows/translate-i18n-base-on-english.yml

 - name: Install dependencies
   if: env.FILES_CHANGED == 'true'
-  run: yarn install --frozen-lockfile
+  run: pnpm install --frozen-lockfile

 - name: Run npm script
   if: env.FILES_CHANGED == 'true'
-  run: npm run auto-gen-i18n
+  run: pnpm run auto-gen-i18n

 - name: Create Pull Request
   if: env.FILES_CHANGED == 'true'

+3 −3   .github/workflows/web-tests.yml

   if: steps.changed-files.outputs.any_changed == 'true'
   with:
     node-version: 20
-    cache: yarn
+    cache: pnpm
     cache-dependency-path: ./web/package.json

 - name: Install dependencies
   if: steps.changed-files.outputs.any_changed == 'true'
-  run: yarn install --frozen-lockfile
+  run: pnpm install --frozen-lockfile

 - name: Run tests
   if: steps.changed-files.outputs.any_changed == 'true'
-  run: yarn test
+  run: pnpm test

+7 −0   .gitignore

 docker/volumes/pgvecto_rs/data/*
 docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
+docker/volumes/plugin_daemon/*
 !docker/volumes/oceanbase/init.d

 docker/nginx/conf.d/default.conf

 .idea/
 .vscode
+
+# pnpm
+/.pnpm-store
+
+# plugin migrate
+plugins.jsonl

+ 5
- 0
api/.dockerignore View File

.env .env
*.env.* *.env.*


storage/generate_files/*
storage/privkeys/* storage/privkeys/*
storage/tools/*
storage/upload_files/*


# Logs # Logs
logs logs


# jetbrains # jetbrains
.idea .idea
.mypy_cache
.ruff_cache


# venv # venv
.venv .venv

+16 −1   api/.env.example

 APP_MAX_EXECUTION_TIME=1200
 APP_MAX_ACTIVE_REQUESTS=0

 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1

 POSITION_PROVIDER_INCLUDES=
 POSITION_PROVIDER_EXCLUDES=

+# Plugin configuration
+PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+PLUGIN_DAEMON_URL=http://127.0.0.1:5002
+PLUGIN_REMOTE_INSTALL_PORT=5003
+PLUGIN_REMOTE_INSTALL_HOST=localhost
+PLUGIN_MAX_PACKAGE_SIZE=15728640
-INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+
+# Marketplace configuration
+MARKETPLACE_ENABLED=true
+MARKETPLACE_API_URL=https://marketplace.dify.ai
+
+# Endpoint configuration
+ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

 # Reset password token expiry minutes
 RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

+4 −0   api/Dockerfile

 # Download nltk data
 RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"

+ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
+
+RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
+
 # Copy source code
 COPY . /app/api/

+70 −3   api/commands.py

 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService
+from services.plugin.data_migration import PluginDataMigration
+from services.plugin.plugin_migration import PluginMigration


 @click.command("reset-password", help="Reset the account password.")
     )
     )

-    except Exception as e:
+    except Exception:
         click.echo(click.style("Failed to create Qdrant client.", fg="red"))

     click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))

     click.echo(click.style("Database migration successful!", fg="green"))

-    except Exception as e:
+    except Exception:
         logging.exception("Failed to execute database migration")
     finally:
         lock.release()

     account = accounts[0]
     print("Fixing missing site for app {}".format(app.id))
     app_was_created.send(app, account=account)
-            except Exception as e:
+            except Exception:
                 failed_app_ids.append(app_id)
                 click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
                 logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
                 break

     click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))


+@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
+def migrate_data_for_plugin():
+    """
+    Migrate data for plugin.
+    """
+    click.echo(click.style("Starting migrate data for plugin.", fg="white"))
+
+    PluginDataMigration.migrate()
+
+    click.echo(click.style("Migrate data for plugin completed.", fg="green"))
+
+
+@click.command("extract-plugins", help="Extract plugins.")
+@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
+@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
+def extract_plugins(output_file: str, workers: int):
+    """
+    Extract plugins.
+    """
+    click.echo(click.style("Starting extract plugins.", fg="white"))
+
+    PluginMigration.extract_plugins(output_file, workers)
+
+    click.echo(click.style("Extract plugins completed.", fg="green"))
+
+
+@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
+@click.option(
+    "--output_file",
+    prompt=True,
+    help="The file to store the extracted unique identifiers.",
+    default="unique_identifiers.json",
+)
+@click.option(
+    "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
+)
+def extract_unique_plugins(output_file: str, input_file: str):
+    """
+    Extract unique plugins.
+    """
+    click.echo(click.style("Starting extract unique plugins.", fg="white"))
+
+    PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
+
+    click.echo(click.style("Extract unique plugins completed.", fg="green"))
+
+
+@click.command("install-plugins", help="Install plugins.")
+@click.option(
+    "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
+)
+@click.option(
+    "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
+)
+def install_plugins(input_file: str, output_file: str):
+    """
+    Install plugins.
+    """
+    click.echo(click.style("Starting install plugins.", fg="white"))
+
+    PluginMigration.install_plugins(input_file, output_file)
+
+    click.echo(click.style("Install plugins completed.", fg="green"))
+60 −0   api/configs/feature/__init__.py

 )


+class PluginConfig(BaseSettings):
+    """
+    Plugin configs
+    """
+
+    PLUGIN_DAEMON_URL: HttpUrl = Field(
+        description="Plugin API URL",
+        default="http://localhost:5002",
+    )
+
+    PLUGIN_DAEMON_KEY: str = Field(
+        description="Plugin API key",
+        default="plugin-api-key",
+    )
+
+    INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
+
+    PLUGIN_REMOTE_INSTALL_HOST: str = Field(
+        description="Plugin Remote Install Host",
+        default="localhost",
+    )
+
+    PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
+        description="Plugin Remote Install Port",
+        default=5003,
+    )
+
+    PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
+        description="Maximum allowed size for plugin packages in bytes",
+        default=15728640,
+    )
+
+    PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
+        description="Maximum allowed size for plugin bundles in bytes",
+        default=15728640 * 12,
+    )
+
+
+class MarketplaceConfig(BaseSettings):
+    """
+    Configuration for marketplace
+    """
+
+    MARKETPLACE_ENABLED: bool = Field(
+        description="Enable or disable marketplace",
+        default=True,
+    )
+
+    MARKETPLACE_API_URL: HttpUrl = Field(
+        description="Marketplace API URL",
+        default="https://marketplace.dify.ai",
+    )
+
+
 class EndpointConfig(BaseSettings):
     """
     Configuration for various application endpoints and URLs
         default="",
     )

+    ENDPOINT_URL_TEMPLATE: str = Field(
+        description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
+    )
+

 class FileAccessConfig(BaseSettings):
     """
     AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
+    PluginConfig,
+    MarketplaceConfig,
     DataSetConfig,
     EndpointConfig,
     FileAccessConfig,
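Each field in the `PluginConfig` hunk above pairs an environment variable with a default, so a bare deployment works and `.env` overrides it. A stdlib-only sketch of that fallback behavior (using a plain dataclass as a stand-in for the pydantic-settings `BaseSettings` machinery, which is not reproduced here):

```python
import os
from dataclasses import dataclass, field


# Illustrative stand-in: each field reads its env var at instantiation time
# and falls back to the same defaults the PluginConfig hunk declares.
@dataclass
class PluginConfigSketch:
    daemon_url: str = field(
        default_factory=lambda: os.environ.get("PLUGIN_DAEMON_URL", "http://localhost:5002")
    )
    max_package_size: int = field(
        default_factory=lambda: int(os.environ.get("PLUGIN_MAX_PACKAGE_SIZE", "15728640"))
    )
    max_bundle_size: int = field(
        default_factory=lambda: int(os.environ.get("PLUGIN_MAX_BUNDLE_SIZE", str(15728640 * 12)))
    )


# With the variable unset, the declared default applies.
os.environ.pop("PLUGIN_DAEMON_URL", None)
default_cfg = PluginConfigSketch()

# Setting the variable (as api/.env.example does) overrides the default.
os.environ["PLUGIN_DAEMON_URL"] = "http://plugin-daemon:5002"
override_cfg = PluginConfigSketch()
```

Note that `.env.example` points `PLUGIN_DAEMON_URL` at `http://127.0.0.1:5002` while the code default is `http://localhost:5002`; both resolve to the local plugin daemon.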

+1 −1   api/configs/packaging/__init__.py

 CURRENT_VERSION: str = Field(
     description="Dify version",
-    default="0.15.3",
+    default="1.0.0",
 )

 COMMIT_SHA: str = Field(

+ 10
- 0
api/contexts/__init__.py View File

from contextvars import ContextVar from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING


if TYPE_CHECKING: if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool



tenant_id: ContextVar[str] = ContextVar("tenant_id") tenant_id: ContextVar[str] = ContextVar("tenant_id")


workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")

plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
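The hunk above pairs each provider registry `ContextVar` with a `Lock` `ContextVar`, so every request context gets its own registry plus a lock guarding mutation of the shared dict inside that context. A self-contained sketch of the pattern (the provider value and function names here are illustrative, not Dify's API):

```python
from contextvars import ContextVar, copy_context
from threading import Lock

# One registry dict and one lock per context, mirroring the pattern above.
plugin_tool_providers: ContextVar[dict] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")


def init_context() -> None:
    """Initialize the per-context registry and its guarding lock."""
    plugin_tool_providers.set({})
    plugin_tool_providers_lock.set(Lock())


def register_provider(name: str, provider: object) -> None:
    """Mutate the context-local registry under its context-local lock."""
    with plugin_tool_providers_lock.get():
        plugin_tool_providers.get()[name] = provider


# Run everything inside one copied context, as a request handler would.
ctx = copy_context()
ctx.run(init_context)
ctx.run(register_provider, "demo_provider", object())
registered = ctx.run(lambda: sorted(plugin_tool_providers.get()))
```

Because the variables are set inside `ctx`, the outer context never sees the registry; concurrent requests each get isolated state without a global lock.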

+ 14
- 2
api/controllers/console/__init__.py View File



from libs.external_api import ExternalApi from libs.external_api import ExternalApi


from .app.app_import import AppImportApi, AppImportConfirmApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import ( from .explore.conversation import (
# Import App # Import App
api.add_resource(AppImportApi, "/apps/imports") api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")


# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version
from .tag import tags from .tag import tags


# Import workspace controllers # Import workspace controllers
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)

+ 23
- 7
api/controllers/console/admin.py View File



from flask import request from flask import request
from flask_restful import Resource, reqparse # type: ignore from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized


from configs import dify_config from configs import dify_config
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()


app = App.query.filter(App.id == args["app_id"]).first()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{args['app_id']}' is not found")


privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""


- recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
+ with Session(db.engine) as session:
+     recommended_app = session.execute(
+         select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
+     ).scalar_one_or_none()


if not recommended_app:
    recommended_app = RecommendedApp(
@only_edition_cloud
@admin_required
def delete(self, app_id):
-     recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
+     with Session(db.engine) as session:
+         recommended_app = session.execute(
+             select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
+         ).scalar_one_or_none()

if not recommended_app:
    return {"result": "success"}, 204


- app = App.query.filter(App.id == recommended_app.app_id).first()
+ with Session(db.engine) as session:
+     app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()

if app:
    app.is_public = False


- installed_apps = InstalledApp.query.filter(
-     InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
- ).all()
+ with Session(db.engine) as session:
+     installed_apps = session.execute(
+         select(InstalledApp).filter(
+             InstalledApp.app_id == recommended_app.app_id,
+             InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
+         )
+     ).all()


for installed_app in installed_apps:
    db.session.delete(installed_app)
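The hunks above all follow the same migration: legacy Flask-SQLAlchemy `Model.query.filter(...).first()` lookups are replaced with an explicit `Session` plus `select()`, resolved via `scalar_one_or_none()`. A minimal, self-contained sketch of that pattern against an in-memory SQLite database (the `App` model here is an illustrative stand-in, not Dify's actual model definition):

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class App(Base):
    __tablename__ = "apps"
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite://")  # in-memory database for the sketch
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(App(id=1, name="demo"))
    session.commit()

# Legacy style: App.query.filter(App.id == 1).first()
# New style, as in the diff: explicit Session + select() + scalar_one_or_none()
with Session(engine) as session:
    found = session.execute(select(App).filter(App.id == 1)).scalar_one_or_none()
    missing = session.execute(select(App).filter(App.id == 2)).scalar_one_or_none()
    found_name = found.name if found is not None else None
```

`scalar_one_or_none()` mirrors `.first()` for the common "at most one row" lookup: it returns the mapped object or `None`, rather than a `Row` tuple.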

api/controllers/console/apikey.py (+12 -1)

import flask_restful  # type: ignore
from flask_login import current_user  # type: ignore
from flask_restful import Resource, fields, marshal_with
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden


from extensions.ext_database import db




def _get_resource(resource_id, tenant_id, resource_model):
-     resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
+     with Session(db.engine) as session:
+         resource = session.execute(
+             select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
+         ).scalar_one_or_none()


    if resource is None:
        flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

api/controllers/console/app/app_import.py (+20 -1)

from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
+ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from extensions.ext_database import db
- from fields.app_fields import app_import_fields
+ from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
+ from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus




if result.status == ImportStatus.FAILED.value:
    return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200


class AppImportCheckDependenciesApi(Resource):
    @setup_required
    @login_required
    @get_app_model
    @account_initialization_required
    @marshal_with(app_import_check_dependencies_fields)
    def get(self, app_model: App):
        if not current_user.is_editor:
            raise Forbidden()

        with Session(db.engine) as session:
            import_service = AppDslService(session)
            result = import_service.check_dependencies(app_model=app_model)

        return result.model_dump(mode="json"), 200

api/controllers/console/app/site.py (+32 -27)



from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from constants.languages import supported_language
if not current_user.is_editor:
    raise Forbidden()


- site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
-
- for attr_name in [
-     "title",
-     "icon_type",
-     "icon",
-     "icon_background",
-     "description",
-     "default_language",
-     "chat_color_theme",
-     "chat_color_theme_inverted",
-     "customize_domain",
-     "copyright",
-     "privacy_policy",
-     "custom_disclaimer",
-     "customize_token_strategy",
-     "prompt_public",
-     "show_workflow_steps",
-     "use_icon_as_answer_icon",
- ]:
-     value = args.get(attr_name)
-     if value is not None:
-         setattr(site, attr_name, value)
-
- site.updated_by = current_user.id
- site.updated_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
+ with Session(db.engine) as session:
+     site = session.query(Site).filter(Site.app_id == app_model.id).first()
+
+     if not site:
+         raise NotFound
+
+     for attr_name in [
+         "title",
+         "icon_type",
+         "icon",
+         "icon_background",
+         "description",
+         "default_language",
+         "chat_color_theme",
+         "chat_color_theme_inverted",
+         "customize_domain",
+         "copyright",
+         "privacy_policy",
+         "custom_disclaimer",
+         "customize_token_strategy",
+         "prompt_public",
+         "show_workflow_steps",
+         "use_icon_as_answer_icon",
+     ]:
+         value = args.get(attr_name)
+         if value is not None:
+             setattr(site, attr_name, value)
+
+     site.updated_by = current_user.id
+     site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+     session.commit()


return site



api/controllers/console/app/workflow.py (+39 -5)

from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
+ from models.account import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
else:
    abort(415)


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
workflow_service = WorkflowService()


try:
    if not current_user.is_editor:
        raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
    raise ConversationCompletedError()
except ValueError as e:
    raise e
- except Exception as e:
+ except Exception:
    logging.exception("internal server error.")
    raise InternalServerError()


if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
    raise ConversationCompletedError()
except ValueError as e:
    raise e
- except Exception as e:
+ except Exception:
    logging.exception("internal server error.")
    raise InternalServerError()


if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
    raise ConversationCompletedError()
except ValueError as e:
    raise e
- except Exception as e:
+ except Exception:
    logging.exception("internal server error.")
    raise InternalServerError()


if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()


+ inputs = args.get("inputs")
+ if inputs is None:
+     raise ValueError("missing inputs")
+
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
-     app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
+     app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
)


return workflow_node_execution
if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)


if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()


+ q = args.get("q")
+
filters = None
- if args.get("q"):
+ if q:
    try:
        filters = json.loads(args.get("q", ""))
    except json.JSONDecodeError:
if not current_user.is_editor:
    raise Forbidden()


+ if not isinstance(current_user, Account):
+     raise Forbidden()
+
if request.data:
    parser = reqparse.RequestParser()
    parser.add_argument("name", type=str, required=False, nullable=True, location="json")

api/controllers/console/auth/forgot_password.py (+7 -3)



from flask import request
from flask_restful import Resource, reqparse  # type: ignore
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session

from constants.languages import languages
from controllers.console import api
else:
    language = "en-US"


- account = Account.query.filter_by(email=args["email"]).first()
+ with Session(db.engine) as session:
+     account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
    if FeatureService.get_system_features().is_allow_register:
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()


- account = Account.query.filter_by(email=reset_data.get("email")).first()
+ with Session(db.engine) as session:
+     account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account:
    account.password = base64_password_hashed
    account.password_salt = base64_salt
)
except WorkSpaceNotAllowedCreateError:
    pass
- except AccountRegisterError as are:
+ except AccountRegisterError:
    raise AccountInFreezeError()


return {"result": "success"}
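The reset flow above derives a salted hash and stores it base64-encoded (`hash_password`, then `base64.b64encode(...).decode()`). A standalone sketch of that flow; PBKDF2-SHA256 is an assumption here, Dify's actual `libs.password.hash_password` may use a different KDF or iteration count:

```python
import base64
import hashlib
import os


def hash_password(password: str, salt: bytes) -> bytes:
    # Hypothetical stand-in for libs.password.hash_password
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 10000)


salt = os.urandom(16)
base64_salt = base64.b64encode(salt).decode()

password_hashed = hash_password("new-password", salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()

# Later verification recomputes the hash from the stored salt and compares.
stored_salt = base64.b64decode(base64_salt)
rehashed = base64.b64encode(hash_password("new-password", stored_salt)).decode()
wrong = base64.b64encode(hash_password("wrong-password", stored_salt)).decode()
```

Storing both the salt and the hash as base64 strings keeps the columns plain text while preserving the raw bytes exactly.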

api/controllers/console/auth/oauth.py (+4 -1)

import requests
from flask import current_app, redirect, request
from flask_restful import Resource  # type: ignore
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized


from configs import dify_config
account: Optional[Account] = Account.get_by_openid(provider, user_info.id)


if not account:
-     account = Account.query.filter_by(email=user_info.email).first()
+     with Session(db.engine) as session:
+         account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()

return account



api/controllers/console/datasets/data_source.py (+63 -49)

from flask import request
from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.console import api
def patch(self, binding_id, action):
    binding_id = str(binding_id)
    action = str(action)
-     data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
+     with Session(db.engine) as session:
+         data_source_binding = session.execute(
+             select(DataSourceOauthBinding).filter_by(id=binding_id)
+         ).scalar_one_or_none()
    if data_source_binding is None:
        raise NotFound("Data source binding not found.")
    # enable binding
def get(self):
    dataset_id = request.args.get("dataset_id", default=None, type=str)
    exist_page_ids = []
-     # import notion in the exist dataset
-     if dataset_id:
-         dataset = DatasetService.get_dataset(dataset_id)
-         if not dataset:
-             raise NotFound("Dataset not found.")
-         if dataset.data_source_type != "notion_import":
-             raise ValueError("Dataset is not notion type.")
-         documents = Document.query.filter_by(
-             dataset_id=dataset_id,
-             tenant_id=current_user.current_tenant_id,
-             data_source_type="notion_import",
-             enabled=True,
-         ).all()
-         if documents:
-             for document in documents:
-                 data_source_info = json.loads(document.data_source_info)
-                 exist_page_ids.append(data_source_info["notion_page_id"])
-     # get all authorized pages
-     data_source_bindings = DataSourceOauthBinding.query.filter_by(
-         tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-     ).all()
-     if not data_source_bindings:
-         return {"notion_info": []}, 200
-     pre_import_info_list = []
-     for data_source_binding in data_source_bindings:
-         source_info = data_source_binding.source_info
-         pages = source_info["pages"]
-         # Filter out already bound pages
-         for page in pages:
-             if page["page_id"] in exist_page_ids:
-                 page["is_bound"] = True
-             else:
-                 page["is_bound"] = False
-         pre_import_info = {
-             "workspace_name": source_info["workspace_name"],
-             "workspace_icon": source_info["workspace_icon"],
-             "workspace_id": source_info["workspace_id"],
-             "pages": pages,
-         }
-         pre_import_info_list.append(pre_import_info)
-     return {"notion_info": pre_import_info_list}, 200
+     with Session(db.engine) as session:
+         # import notion in the exist dataset
+         if dataset_id:
+             dataset = DatasetService.get_dataset(dataset_id)
+             if not dataset:
+                 raise NotFound("Dataset not found.")
+             if dataset.data_source_type != "notion_import":
+                 raise ValueError("Dataset is not notion type.")
+
+             documents = session.execute(
+                 select(Document).filter_by(
+                     dataset_id=dataset_id,
+                     tenant_id=current_user.current_tenant_id,
+                     data_source_type="notion_import",
+                     enabled=True,
+                 )
+             ).all()
+             if documents:
+                 for document in documents:
+                     data_source_info = json.loads(document.data_source_info)
+                     exist_page_ids.append(data_source_info["notion_page_id"])
+         # get all authorized pages
+         data_source_bindings = session.scalars(
+             select(DataSourceOauthBinding).filter_by(
+                 tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
+             )
+         ).all()
+         if not data_source_bindings:
+             return {"notion_info": []}, 200
+         pre_import_info_list = []
+         for data_source_binding in data_source_bindings:
+             source_info = data_source_binding.source_info
+             pages = source_info["pages"]
+             # Filter out already bound pages
+             for page in pages:
+                 if page["page_id"] in exist_page_ids:
+                     page["is_bound"] = True
+                 else:
+                     page["is_bound"] = False
+             pre_import_info = {
+                 "workspace_name": source_info["workspace_name"],
+                 "workspace_icon": source_info["workspace_icon"],
+                 "workspace_id": source_info["workspace_id"],
+                 "pages": pages,
+             }
+             pre_import_info_list.append(pre_import_info)
+         return {"notion_info": pre_import_info_list}, 200




class DataSourceNotionApi(Resource):
    def get(self, workspace_id, page_id, page_type):
        workspace_id = str(workspace_id)
        page_id = str(page_id)
-         data_source_binding = DataSourceOauthBinding.query.filter(
-             db.and_(
-                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                 DataSourceOauthBinding.provider == "notion",
-                 DataSourceOauthBinding.disabled == False,
-                 DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-             )
-         ).first()
+         with Session(db.engine) as session:
+             data_source_binding = session.execute(
+                 select(DataSourceOauthBinding).filter(
+                     db.and_(
+                         DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                         DataSourceOauthBinding.provider == "notion",
+                         DataSourceOauthBinding.disabled == False,
+                         DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                     )
+                 )
+             ).scalar_one_or_none()
        if not data_source_binding:
            raise NotFound("Data source binding not found.")



api/controllers/console/datasets/datasets.py (+3 -0)

from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
+ from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting


data = marshal(datasets, dataset_detail_fields)
for item in data:
+     # convert embedding_model_provider to plugin standard format
    if item["indexing_technique"] == "high_quality":
+         item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
        item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
        if item_model in model_names:
            item["embedding_available"] = True

api/controllers/console/datasets/datasets_document.py (+19 -3)

from flask_login import current_user  # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with, reqparse  # type: ignore
from sqlalchemy import asc, desc
- from transformers.hf_argparser import string_to_bool  # type: ignore
from werkzeug.exceptions import Forbidden, NotFound

import services
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
+ from core.plugin.manager.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
sort = request.args.get("sort", default="-created_at", type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
-     fetch = string_to_bool(request.args.get("fetch", default="false"))
+     fetch_val = request.args.get("fetch", default="false")
+     if isinstance(fetch_val, bool):
+         fetch = fetch_val
+     else:
+         if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
+             fetch = True
+         elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
+             fetch = False
+         else:
+             raise ArgumentTypeError(
+                 f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
+                 f"(case insensitive)."
+             )
- except (ArgumentTypeError, ValueError, Exception) as e:
+ except (ArgumentTypeError, ValueError, Exception):
    fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
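This hunk drops the `transformers` import that was pulled in solely for `string_to_bool`. Factored into a helper, the same inline parsing looks like this (a sketch; the diff itself keeps the logic inline in the controller):

```python
from argparse import ArgumentTypeError


def str_to_bool(value):
    """Parse common truthy/falsy strings, mirroring the inline logic in the diff."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {value!r} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
        f"(case insensitive)."
    )
```

The controller wraps the parse in `try/except` and falls back to `fetch = False`, so an unparseable query string degrades gracefully instead of returning a 500.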
)
except ProviderTokenNotInitError as ex:
    raise ProviderNotInitializeError(ex.description)
+ except PluginDaemonClientSideError as ex:
+     raise ProviderNotInitializeError(ex.description)
except Exception as e:
    raise IndexingEstimateError(str(e))


)
except ProviderTokenNotInitError as ex:
    raise ProviderNotInitializeError(ex.description)
+ except PluginDaemonClientSideError as ex:
+     raise ProviderNotInitializeError(ex.description)
except Exception as e:
    raise IndexingEstimateError(str(e))



api/controllers/console/init_validate.py (+8 -1)



from flask import session
from flask_restful import Resource, reqparse  # type: ignore
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session

from configs import dify_config
+ from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
def get_init_validate_status():
    if dify_config.EDITION == "SELF_HOSTED":
        if os.environ.get("INIT_PASSWORD"):
-             return session.get("is_init_validated") or DifySetup.query.first()
+             if session.get("is_init_validated"):
+                 return True
+
+             with Session(db.engine) as db_session:
+                 return db_session.execute(select(DifySetup)).scalar_one_or_none()

    return True



api/controllers/console/setup.py (+4 -3)

from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
from libs.password import valid_password
- from models.model import DifySetup
+ from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService


from . import api


def get_setup_status():
    if dify_config.EDITION == "SELF_HOSTED":
-         return DifySetup.query.first()
-     return True
+         return db.session.query(DifySetup).first()
+     else:
+         return True




api.add_resource(SetupApi, "/setup")

api/controllers/console/workspace/__init__.py (+56 -0)

from functools import wraps

from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from extensions.ext_database import db
from models.account import TenantPluginPermission


def plugin_permission_required(
    install_required: bool = False,
    debug_required: bool = False,
):
    def interceptor(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            user = current_user
            tenant_id = user.current_tenant_id

            with Session(db.engine) as session:
                permission = (
                    session.query(TenantPluginPermission)
                    .filter(
                        TenantPluginPermission.tenant_id == tenant_id,
                    )
                    .first()
                )

                if not permission:
                    # no permission set, allow access for everyone
                    return view(*args, **kwargs)

                if install_required:
                    if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
                        raise Forbidden()
                    if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
                        if not user.is_admin_or_owner:
                            raise Forbidden()
                    if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
                        pass

                if debug_required:
                    if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
                        raise Forbidden()
                    if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
                        if not user.is_admin_or_owner:
                            raise Forbidden()
                    if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
                        pass

            return view(*args, **kwargs)

        return decorated

    return interceptor
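The decorator above gates plugin install/debug views on a per-tenant permission row. A simplified, self-contained analogue of the same pattern, with the enum levels and the `session.query(...)` lookup replaced by plain stand-ins (`permission_row`, `User`, and `install_plugin` are all hypothetical):

```python
from functools import wraps


class Forbidden(Exception):
    pass


class User:
    def __init__(self, is_admin_or_owner: bool):
        self.is_admin_or_owner = is_admin_or_owner


# Stand-in for the TenantPluginPermission row; None would mean "no row found".
permission_row = {"install_permission": "admins"}


def plugin_permission_required(install_required: bool = False):
    def interceptor(view):
        @wraps(view)
        def decorated(user, *args, **kwargs):
            permission = permission_row  # stands in for the DB lookup
            if not permission:
                # no permission set, allow access for everyone
                return view(user, *args, **kwargs)
            if install_required:
                level = permission["install_permission"]
                if level == "nobody":
                    raise Forbidden()
                if level == "admins" and not user.is_admin_or_owner:
                    raise Forbidden()
            return view(user, *args, **kwargs)

        return decorated

    return interceptor


@plugin_permission_required(install_required=True)
def install_plugin(user):
    return "installed"
```

Note the "missing row means allow" default: a workspace that has never configured plugin permissions behaves as `EVERYONE`.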

api/controllers/console/workspace/agent_providers.py (+36 -0)

from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore

from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.agent_service import AgentService


class AgentProviderListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        user = current_user

        user_id = user.id
        tenant_id = user.current_tenant_id

        return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))


class AgentProviderApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_name: str):
        user = current_user
        user_id = user.id
        tenant_id = user.current_tenant_id
        return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))


api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
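Both views pass the service result through `jsonable_encoder` before returning it, so provider entities (which may contain enums or nested models) come out as plain JSON-safe structures. A rough standalone analogue of that conversion, not Dify's actual `core.model_runtime.utils.encoders` implementation:

```python
import json
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum


def jsonable(obj):
    """Minimal analogue of jsonable_encoder: recursively convert to JSON-safe types."""
    if is_dataclass(obj):
        return {key: jsonable(value) for key, value in asdict(obj).items()}
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, dict):
        return {key: jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [jsonable(value) for value in obj]
    return obj


class ProviderType(Enum):
    AGENT = "agent"


@dataclass
class AgentProvider:
    name: str
    type: ProviderType


payload = jsonable([AgentProvider(name="demo", type=ProviderType.AGENT)])
encoded = json.dumps(payload)
```

Without this step, `json.dumps` would raise a `TypeError` on the enum member.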

api/controllers/console/workspace/endpoint.py (+205 -0)

from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService


class EndpointCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user = current_user
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("plugin_unique_identifier", type=str, required=True)
        parser.add_argument("settings", type=dict, required=True)
        parser.add_argument("name", type=str, required=True)
        args = parser.parse_args()

        plugin_unique_identifier = args["plugin_unique_identifier"]
        settings = args["settings"]
        name = args["name"]

        return {
            "success": EndpointService.create_endpoint(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                plugin_unique_identifier=plugin_unique_identifier,
                name=name,
                settings=settings,
            )
        }


class EndpointListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        user = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("page", type=int, required=True, location="args")
        parser.add_argument("page_size", type=int, required=True, location="args")
        args = parser.parse_args()

        page = args["page"]
        page_size = args["page_size"]

        return jsonable_encoder(
            {
                "endpoints": EndpointService.list_endpoints(
                    tenant_id=user.current_tenant_id,
                    user_id=user.id,
                    page=page,
                    page_size=page_size,
                )
            }
        )


class EndpointListForSinglePluginApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        user = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("page", type=int, required=True, location="args")
        parser.add_argument("page_size", type=int, required=True, location="args")
        parser.add_argument("plugin_id", type=str, required=True, location="args")
        args = parser.parse_args()

        page = args["page"]
        page_size = args["page_size"]
        plugin_id = args["plugin_id"]

        return jsonable_encoder(
            {
                "endpoints": EndpointService.list_endpoints_for_single_plugin(
                    tenant_id=user.current_tenant_id,
                    user_id=user.id,
                    plugin_id=plugin_id,
                    page=page,
                    page_size=page_size,
                )
            }
        )


class EndpointDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("endpoint_id", type=str, required=True)
        args = parser.parse_args()

        if not user.is_admin_or_owner:
            raise Forbidden()

        endpoint_id = args["endpoint_id"]

        return {
            "success": EndpointService.delete_endpoint(
                tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
            )
        }


class EndpointUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("endpoint_id", type=str, required=True)
        parser.add_argument("settings", type=dict, required=True)
        parser.add_argument("name", type=str, required=True)
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]
        settings = args["settings"]
        name = args["name"]

        if not user.is_admin_or_owner:
            raise Forbidden()

        return {
            "success": EndpointService.update_endpoint(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                endpoint_id=endpoint_id,
                name=name,
                settings=settings,
            )
        }


class EndpointEnableApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user = current_user

        parser = reqparse.RequestParser()
        parser.add_argument("endpoint_id", type=str, required=True)
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]

        if not user.is_admin_or_owner:
            raise Forbidden()

        return {
            "success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}


class EndpointDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user

parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()

endpoint_id = args["endpoint_id"]

if not user.is_admin_or_owner:
raise Forbidden()

return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}


api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
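Every mutating operation in this controller is a POST under `/workspaces/current/endpoints/<action>` that answers `{"success": <bool>}`; only the two list routes are GETs. A minimal stdlib sketch of a client that builds those request pairs — the `CONSOLE_API` prefix and the example identifier are illustrative assumptions, not values from this PR:

```python
import json

# Hypothetical console-API mount point; the real prefix depends on deployment.
CONSOLE_API = "/console/api"

VALID_ACTIONS = {"create", "delete", "update", "enable", "disable"}

def endpoint_request(action: str, payload: dict) -> tuple[str, bytes]:
    """Build the (url, json_body) pair for one of the mutating endpoint routes."""
    if action not in VALID_ACTIONS:
        raise ValueError(f"unknown endpoint action: {action}")
    url = f"{CONSOLE_API}/workspaces/current/endpoints/{action}"
    return url, json.dumps(payload).encode("utf-8")

url, body = endpoint_request(
    "create",
    {
        "plugin_unique_identifier": "example/plugin:0.0.1@abc",  # illustrative
        "name": "my-endpoint",
        "settings": {"api-key": "..."},
    },
)
```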

+ 2
- 2
api/controllers/console/workspace/load_balancing_config.py View File

# Load Balancing Config
api.add_resource(
    LoadBalancingCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)


api.add_resource(
    LoadBalancingConfigCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)
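The `<string:provider>` → `<path:provider>` switch matters because plugin-era provider names (e.g. `langgenius/openai/openai`) contain slashes. These patterns mirror the regexes werkzeug generates for the two converters — `<string:...>` matches a single path segment, `<path:...>` matches across segments:

```python
import re

STRING_SEG = r"[^/]+"   # how <string:...> matches
PATH_SEG = r"[^/].*?"   # how <path:...> matches

old_rule = re.compile(rf"^/model-providers/({STRING_SEG})/credentials-validate$")
new_rule = re.compile(rf"^/model-providers/({PATH_SEG})/credentials-validate$")

url = "/model-providers/langgenius/openai/openai/credentials-validate"

# <string:provider> cannot swallow the embedded slashes...
assert old_rule.match(url) is None
# ...while <path:provider> captures the full provider id.
assert new_rule.match(url).group(1) == "langgenius/openai/openai"
```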

+ 10
- 45
api/controllers/console/workspace/model_providers.py View File

        response = {"result": "success" if result else "error"}

        if not result:
-            response["error"] = error
+            response["error"] = error or "Unknown error"

        return response


    Get model provider icon
    """

-    def get(self, provider: str, icon_type: str, lang: str):
+    def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
        model_provider_service = ModelProviderService()
        icon, mimetype = model_provider_service.get_model_provider_icon(
+            tenant_id=tenant_id,
            provider=provider,
            icon_type=icon_type,
            lang=lang,

        return data


-class ModelProviderFreeQuotaSubmitApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider: str):
-        model_provider_service = ModelProviderService()
-        result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
-
-        return result
-
-
-class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument("token", type=str, required=False, nullable=True, location="args")
-        args = parser.parse_args()
-
-        model_provider_service = ModelProviderService()
-        result = model_provider_service.free_quota_qualification_verify(
-            tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
-        )
-
-        return result
-
-
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")

-api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
-api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
-api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
-api.add_resource(
-    ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
-)
+api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
+api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
+api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")

api.add_resource(
-    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
-)
-api.add_resource(
-    ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
-)
-api.add_resource(
-    ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
+    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
+api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
-    ModelProviderFreeQuotaQualificationVerifyApi,
-    "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
+    ModelProviderIconApi,
+    "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)
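The icon route moves from the implicit current-workspace scope to a URL with an explicit `tenant_id` segment (matching the handler's new signature), and its provider segment becomes `<path:...>` so slashed plugin-provider ids survive. A hedged sketch of building the new URL — the helper name is hypothetical:

```python
# Hypothetical helper reflecting the relocated, tenant-scoped icon route.
def icon_url(tenant_id: str, provider: str, icon_type: str, lang: str) -> str:
    return f"/workspaces/{tenant_id}/model-providers/{provider}/{icon_type}/{lang}"

# e.g. an icon for a provider with a slashed identifier:
example = icon_url("tenant-123", "langgenius/openai/openai", "icon_small", "en_US")
```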

+ 7
- 7
api/controllers/console/workspace/models.py View File

        response = {"result": "success" if result else "error"}

        if not result:
-            response["error"] = error
+            response["error"] = error or ""

        return response


        return jsonable_encoder({"data": models})


-api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
+api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
    ModelProviderModelEnableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/enable",
+    "/workspaces/current/model-providers/<path:provider>/models/enable",
    endpoint="model-provider-model-enable",
)
api.add_resource(
    ModelProviderModelDisableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/disable",
+    "/workspaces/current/model-providers/<path:provider>/models/disable",
    endpoint="model-provider-model-disable",
)
api.add_resource(
-    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
+    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
-    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
+    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)

api.add_resource(
-    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
+    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

+ 475
- 0
api/controllers/console/workspace/plugin.py View File

import io

from flask import request, send_file
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.manager.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService


class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id

try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
try:
plugins = PluginService.list(tenant_id)
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder({"plugins": plugins})


class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()

try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder({"plugins": plugins})


class PluginIconApi(Resource):
@setup_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("tenant_id", type=str, required=True, location="args")
req.add_argument("filename", type=str, required=True, location="args")
args = req.parse_args()

try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
except PluginDaemonClientSideError as e:
raise ValueError(e)

icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)


class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

file = request.files["pkg"]

# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")

content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()

try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

file = request.files["bundle"]

# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")

content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()

# check if all plugin_unique_identifiers are valid strings
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")

try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()

try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
args = parser.parse_args()

# check if all plugin_unique_identifiers are valid strings
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")

try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e:
raise ValueError(e)

return jsonable_encoder(response)


class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()

try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
args = parser.parse_args()

try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id

try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id

try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id

try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
args = parser.parse_args()

try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
args = parser.parse_args()

try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()

tenant_id = current_user.current_tenant_id

try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
except PluginDaemonClientSideError as e:
raise ValueError(e)


class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()

req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
args = req.parse_args()

install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])

tenant_id = user.current_tenant_id

return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}


class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id

permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
{
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
)

return jsonable_encoder(
{
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
)


api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")

api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
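Both install handlers above repeat the same element-type loop because reqparse's `type=list` only verifies that the JSON value is an array, not what the array contains. A minimal extraction of that guard (the helper name is hypothetical):

```python
def validate_plugin_unique_identifiers(values: list) -> list:
    """Reject any non-string element, mirroring the check in both handlers."""
    for value in values:
        if not isinstance(value, str):
            raise ValueError("Invalid plugin unique identifier")
    return values

ok = validate_plugin_unique_identifiers(["langgenius/openai:0.0.1@abc"])
```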

+ 108
- 52
api/controllers/console/workspace/tool_providers.py View File

    @login_required
    @account_initialization_required
    def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        req = reqparse.RequestParser()
        req.add_argument(

    @login_required
    @account_initialization_required
    def get(self, provider):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        tenant_id = user.current_tenant_id

        return jsonable_encoder(
            BuiltinToolManageService.list_builtin_tool_provider_tools(
-                user_id,
                tenant_id,
                provider,
            )
        )


+class ToolBuiltinProviderInfoApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
+        return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))


class ToolBuiltinProviderDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        return BuiltinToolManageService.delete_builtin_tool_provider(
            user_id,

    @login_required
    @account_initialization_required
    def post(self, provider):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def get(self):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
        parser = reqparse.RequestParser()

        parser.add_argument("url", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()

        return ApiToolManageService.get_api_tool_provider_remote_schema(
-            current_user.id,
-            current_user.current_tenant_id,
+            user_id,
+            tenant_id,
            args["url"],
        )

    @login_required
    @account_initialization_required
    def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()

    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()

    @login_required
    @account_initialization_required
    def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()

    @login_required
    @account_initialization_required
    def get(self, provider):
-        return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
+        user = current_user
+
+        tenant_id = user.current_tenant_id
+
+        return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)


class ToolApiProviderSchemaApi(Resource):
    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        reqparser = reqparse.RequestParser()
        reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        reqparser = reqparse.RequestParser()
        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
            raise Forbidden()

-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        reqparser = reqparse.RequestParser()
        reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")

    @login_required
    @account_initialization_required
    def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")

    @login_required
    @account_initialization_required
    def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
tenant_id = user.current_tenant_id


parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user

user_id = user.id
tenant_id = user.current_tenant_id


return jsonable_encoder( return jsonable_encoder(
[ [
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user

user_id = user.id
tenant_id = user.current_tenant_id


return jsonable_encoder( return jsonable_encoder(
[ [
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
user = current_user

user_id = user.id
tenant_id = user.current_tenant_id


return jsonable_encoder( return jsonable_encoder(
[ [
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")


# builtin tool provider # builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource( api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
) )
api.add_resource( api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
) )
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")


# api tool provider # api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
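The registrations above switch the URL converter from `<provider>` to `<path:provider>`. A minimal sketch of why, assuming a namespaced provider ID such as `org/plugin/google` (the ID format here is illustrative): Flask's default string converter stops at the first slash, so a provider name containing slashes only matches with the `path` converter.

```python
from flask import Flask

app = Flask(__name__)


# With the default <provider> converter this route would 404 for a
# namespaced ID like "org/plugin/google"; <path:provider> matches
# across slashes.
@app.route("/tool-provider/builtin/<path:provider>/tools")
def list_tools(provider):
    return {"provider": provider}


with app.test_client() as client:
    resp = client.get("/tool-provider/builtin/org/plugin/google/tools")
    print(resp.get_json()["provider"])  # org/plugin/google
```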

api/controllers/console/wraps.py (+7 -2)



from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService

    @wraps(view)
    def decorated(*args, **kwargs):
        # check setup
        if (
            dify_config.EDITION == "SELF_HOSTED"
            and os.environ.get("INIT_PASSWORD")
            and not db.session.query(DifySetup).first()
        ):
            raise NotInitValidateError()
        elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
            raise NotSetupError()

        return view(*args, **kwargs)

api/controllers/files/__init__.py (+1 -1)

api = ExternalApi(bp)

from . import image_preview, tool_files, upload

api/controllers/files/upload.py (+69 -0)

from flask import request
from flask_restful import Resource, marshal_with  # type: ignore
from werkzeug.exceptions import Forbidden

import services
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from fields.file_fields import file_fields
from services.file_service import FileService


class PluginUploadFileApi(Resource):
    @setup_required
    @marshal_with(file_fields)
    def post(self):
        # get file from request
        file = request.files["file"]

        timestamp = request.args.get("timestamp")
        nonce = request.args.get("nonce")
        sign = request.args.get("sign")
        tenant_id = request.args.get("tenant_id")
        if not tenant_id:
            raise Forbidden("Invalid request.")

        user_id = request.args.get("user_id")
        user = get_user(tenant_id, user_id)

        filename = file.filename
        mimetype = file.mimetype

        if not filename or not mimetype:
            raise Forbidden("Invalid request.")

        if not timestamp or not nonce or not sign:
            raise Forbidden("Invalid request.")

        if not verify_plugin_file_signature(
            filename=filename,
            mimetype=mimetype,
            tenant_id=tenant_id,
            user_id=user_id,
            timestamp=timestamp,
            nonce=nonce,
            sign=sign,
        ):
            raise Forbidden("Invalid request.")

        try:
            upload_file = FileService.upload_file(
                filename=filename,
                content=file.read(),
                mimetype=mimetype,
                user=user,
                source=None,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return upload_file, 201


api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
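The handler above rejects any upload whose query-string signature fails `verify_plugin_file_signature`. The exact message layout that function signs is not shown in this diff; the sketch below is a generic HMAC-over-fields scheme under that stated assumption (the field order, the `|` separator, and the shared secret are all illustrative).

```python
import hashlib
import hmac
import os
import time

SECRET = b"shared-plugin-secret"  # assumption: a key shared with the signer


def sign_upload(filename, mimetype, tenant_id, user_id, timestamp, nonce):
    # Assumed message layout; the real verify_plugin_file_signature may
    # concatenate these fields differently.
    msg = "|".join([filename, mimetype, tenant_id, user_id or "", timestamp, nonce])
    return hmac.new(SECRET, msg.encode(), hashlib.sha256).hexdigest()


def verify(filename, mimetype, tenant_id, user_id, timestamp, nonce, sign):
    expected = sign_upload(filename, mimetype, tenant_id, user_id, timestamp, nonce)
    # compare_digest avoids timing side channels on the signature check
    return hmac.compare_digest(expected, sign)


ts, nonce = str(int(time.time())), os.urandom(8).hex()
sig = sign_upload("a.png", "image/png", "t1", "u1", ts, nonce)
print(verify("a.png", "image/png", "t1", "u1", ts, nonce, sig))  # True
```

The timestamp and nonce travel alongside the signature so the server can additionally reject stale or replayed requests.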

api/controllers/inner_api/__init__.py (+1 -0)

bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp)

from .plugin import plugin
from .workspace import workspace

api/core/model_runtime/model_providers/anthropic/__init__.py → api/controllers/inner_api/plugin/__init__.py


api/controllers/inner_api/plugin/plugin.py (+293 -0)

from flask_restful import Resource  # type: ignore

from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeApp,
    RequestInvokeEncrypt,
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeParameterExtractorNode,
    RequestInvokeQuestionClassifierNode,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTool,
    RequestInvokeTTS,
    RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Account, Tenant
from models.model import EndUser


class PluginInvokeLLMApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeLLM)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
        def generator():
            response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
            return PluginModelBackwardsInvocation.convert_to_event_stream(response)

        return compact_generate_response(generator())


class PluginInvokeTextEmbeddingApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTextEmbedding)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_text_embedding(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeRerankApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeRerank)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_rerank(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeTTSApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTTS)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
        def generator():
            response = PluginModelBackwardsInvocation.invoke_tts(
                user_id=user_model.id,
                tenant=tenant_model,
                payload=payload,
            )
            return PluginModelBackwardsInvocation.convert_to_event_stream(response)

        return compact_generate_response(generator())


class PluginInvokeSpeech2TextApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeSpeech2Text)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_speech2text(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeModerationApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeModeration)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_moderation(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeToolApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTool)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
        def generator():
            return PluginToolBackwardsInvocation.convert_to_event_stream(
                PluginToolBackwardsInvocation.invoke_tool(
                    tenant_id=tenant_model.id,
                    user_id=user_model.id,
                    tool_type=ToolProviderType.value_of(payload.tool_type),
                    provider=payload.provider,
                    tool_name=payload.tool,
                    tool_parameters=payload.tool_parameters,
                ),
            )

        return compact_generate_response(generator())


class PluginInvokeParameterExtractorNodeApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
                        tenant_id=tenant_model.id,
                        user_id=user_model.id,
                        parameters=payload.parameters,
                        model_config=payload.model,
                        instruction=payload.instruction,
                        query=payload.query,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeQuestionClassifierNodeApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginNodeBackwardsInvocation.invoke_question_classifier(
                        tenant_id=tenant_model.id,
                        user_id=user_model.id,
                        query=payload.query,
                        model_config=payload.model,
                        classes=payload.classes,
                        instruction=payload.instruction,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeAppApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeApp)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
        response = PluginAppBackwardsInvocation.invoke_app(
            app_id=payload.app_id,
            user_id=user_model.id,
            tenant_id=tenant_model.id,
            conversation_id=payload.conversation_id,
            query=payload.query,
            stream=payload.response_mode == "streaming",
            inputs=payload.inputs,
            files=payload.files,
        )

        return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))


class PluginInvokeEncryptApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeEncrypt)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
        """
        encrypt or decrypt data
        """
        try:
            return BaseBackwardsInvocationResponse(
                data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
            ).model_dump()
        except Exception as e:
            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()


class PluginInvokeSummaryApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeSummary)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
        try:
            return BaseBackwardsInvocationResponse(
                data={
                    "summary": PluginModelBackwardsInvocation.invoke_summary(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                }
            ).model_dump()
        except Exception as e:
            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()


class PluginUploadFileRequestApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestRequestUploadFile)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
        # generate signed url
        url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
        return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()


api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")

api/controllers/inner_api/plugin/wraps.py (+116 -0)

from collections.abc import Callable
from functools import wraps
from typing import Optional

from flask import request
from flask_restful import reqparse  # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService


def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
    try:
        with Session(db.engine) as session:
            if not user_id:
                user_id = "DEFAULT-USER"

            if user_id == "DEFAULT-USER":
                user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
                if not user_model:
                    user_model = EndUser(
                        tenant_id=tenant_id,
                        type="service_api",
                        is_anonymous=True if user_id == "DEFAULT-USER" else False,
                        session_id=user_id,
                    )
                    session.add(user_model)
                    session.commit()
            else:
                user_model = AccountService.load_user(user_id)
                if not user_model:
                    user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
                if not user_model:
                    raise ValueError("user not found")
    except Exception:
        raise ValueError("user not found")

    return user_model


def get_user_tenant(view: Optional[Callable] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            # fetch json body
            parser = reqparse.RequestParser()
            parser.add_argument("tenant_id", type=str, required=True, location="json")
            parser.add_argument("user_id", type=str, required=True, location="json")

            kwargs = parser.parse_args()

            user_id = kwargs.get("user_id")
            tenant_id = kwargs.get("tenant_id")

            if not tenant_id:
                raise ValueError("tenant_id is required")

            if not user_id:
                user_id = "DEFAULT-USER"

            del kwargs["tenant_id"]
            del kwargs["user_id"]

            try:
                tenant_model = (
                    db.session.query(Tenant)
                    .filter(
                        Tenant.id == tenant_id,
                    )
                    .first()
                )
            except Exception:
                raise ValueError("tenant not found")

            if not tenant_model:
                raise ValueError("tenant not found")

            kwargs["tenant_model"] = tenant_model
            kwargs["user_model"] = get_user(tenant_id, user_id)

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)


def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
    def decorator(view_func):
        def decorated_view(*args, **kwargs):
            try:
                data = request.get_json()
            except Exception:
                raise ValueError("invalid json")

            try:
                payload = payload_type(**data)
            except Exception as e:
                raise ValueError(f"invalid payload: {str(e)}")

            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
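`plugin_data` above follows the optional-argument decorator pattern: callable both bare (`@plugin_data` would pass the view directly) and with keyword arguments. A framework-free sketch of the same pattern, using a plain dict as the request body and a stand-in dataclass instead of the real pydantic request models (`RequestInvokeSummary` here is a local stand-in, not an import from the codebase):

```python
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Optional


def plugin_data(view: Optional[Callable] = None, *, payload_type):
    # Same shape as the decorator above: if `view` is None we were called
    # with arguments and must return a decorator; otherwise decorate directly.
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(data, *args, **kwargs):
            try:
                payload = payload_type(**data)  # validate/construct the payload
            except TypeError as e:
                raise ValueError(f"invalid payload: {e}")
            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    return decorator if view is None else decorator(view)


@dataclass
class RequestInvokeSummary:  # stand-in for the pydantic request model
    text: str
    instruction: str = ""


@plugin_data(payload_type=RequestInvokeSummary)
def handler(payload):
    return payload.text.upper()


print(handler({"text": "hello"}))  # HELLO
```

The injected `payload` keyword is what lets the `PluginInvoke*Api.post` methods declare a typed `payload` parameter.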

api/controllers/inner_api/workspace/workspace.py (+3 -3)



from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created
from models.account import Account
from services.account_service import TenantService


class EnterpriseWorkspace(Resource):
    @setup_required
    @enterprise_inner_api_only
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")


class EnterpriseWorkspaceNoOwnerEmail(Resource):
    @setup_required
    @enterprise_inner_api_only
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")

api/controllers/inner_api/wraps.py (+19 -3)

from models.model import EndUser


def enterprise_inner_api_only(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not dify_config.INNER_API:

        # get header 'X-Inner-Api-Key'
        inner_api_key = request.headers.get("X-Inner-Api-Key")
        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
            abort(401)

        return view(*args, **kwargs)

    return decorated


def enterprise_inner_api_user_auth(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not dify_config.INNER_API:
            return view(*args, **kwargs)

    return decorated


def plugin_inner_api_only(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not dify_config.PLUGIN_DAEMON_KEY:
            abort(404)

        # get header 'X-Inner-Api-Key'
        inner_api_key = request.headers.get("X-Inner-Api-Key")
        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
            abort(404)

        return view(*args, **kwargs)

    return decorated
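A minimal restatement of `plugin_inner_api_only`'s gate as a pure function (the function name and signature below are illustrative, not part of the codebase): unlike the enterprise decorator, which answers 401 on a bad key, the plugin variant answers 404 both when the daemon key is unset and when the `X-Inner-Api-Key` header is wrong, so unauthenticated callers cannot distinguish the plugin endpoints from nonexistent routes.

```python
def check_plugin_inner_api(headers: dict, daemon_key: str, expected_key: str) -> int:
    """Return the HTTP status the decorator's gate would produce."""
    if not daemon_key:  # PLUGIN_DAEMON_KEY unset -> plugin feature disabled
        return 404
    if headers.get("X-Inner-Api-Key") != expected_key:
        return 404  # wrong or missing key: hide the endpoint, not 401
    return 200


print(check_plugin_inner_api({"X-Inner-Api-Key": "k"}, "daemon-key", "k"))  # 200
```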

api/core/agent/base_agent_runner.py (+44 -70)

import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
                agent_tool=tool,
                invoke_from=self.application_generate_entity.invoke_from,
            )

            assert tool_entity.entity.description
            message_tool = PromptMessageTool(
                name=tool.tool_name,
                description=tool_entity.entity.description.llm,
                parameters={
                    "type": "object",
                    "properties": {},
                },
            )

            parameters = tool_entity.get_merged_runtime_parameters()
            for parameter in parameters:
                if parameter.form != ToolParameter.ToolParameterForm.LLM:
                    continue
continue continue
""" """
convert dataset retriever tool to prompt message tool convert dataset retriever tool to prompt message tool
""" """
assert tool.entity.description

prompt_tool = PromptMessageTool( prompt_tool = PromptMessageTool(
name=tool.identity.name if tool.identity else "unknown",
description=tool.description.llm if tool.description else "",
name=tool.entity.identity.name,
description=tool.entity.description.llm,
parameters={ parameters={
"type": "object", "type": "object",
"properties": {}, "properties": {},
# save prompt tool # save prompt tool
prompt_messages_tools.append(prompt_tool) prompt_messages_tools.append(prompt_tool)
# save tool entity # save tool entity
if dataset_tool.identity is not None:
tool_instances[dataset_tool.identity.name] = dataset_tool
tool_instances[dataset_tool.entity.identity.name] = dataset_tool


return tool_instances, prompt_messages_tools return tool_instances, prompt_messages_tools


def save_agent_thought( def save_agent_thought(
self, self,
agent_thought: MessageAgentThought, agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
tool_name: str | None,
tool_input: Union[str, dict, None],
thought: str | None,
observation: Union[str, dict, None], observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None], tool_invoke_meta: Union[str, dict, None],
answer: str,
answer: str | None,
messages_ids: list[str], messages_ids: list[str],
llm_usage: LLMUsage | None = None, llm_usage: LLMUsage | None = None,
): ):
""" """
Save agent thought Save agent thought
""" """
queried_thought = (
updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
) )
if not queried_thought:
raise ValueError(f"Agent thought {agent_thought.id} not found")
agent_thought = queried_thought
if not updated_agent_thought:
raise ValueError("agent thought not found")
agent_thought = updated_agent_thought


if thought: if thought:
agent_thought.thought = thought agent_thought.thought = thought
if isinstance(tool_input, dict): if isinstance(tool_input, dict):
try: try:
tool_input = json.dumps(tool_input, ensure_ascii=False) tool_input = json.dumps(tool_input, ensure_ascii=False)
except Exception as e:
except Exception:
tool_input = json.dumps(tool_input) tool_input = json.dumps(tool_input)


agent_thought.tool_input = tool_input
updated_agent_thought.tool_input = tool_input


if observation: if observation:
if isinstance(observation, dict): if isinstance(observation, dict):
try: try:
observation = json.dumps(observation, ensure_ascii=False) observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
except Exception:
observation = json.dumps(observation) observation = json.dumps(observation)


agent_thought.observation = observation
updated_agent_thought.observation = observation


if answer: if answer:
agent_thought.answer = answer agent_thought.answer = answer


if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
updated_agent_thought.message_files = json.dumps(messages_ids)


if llm_usage: if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
updated_agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price


# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
labels = updated_agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
for tool in tools:
if not tool:
continue
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}


agent_thought.tool_labels_str = json.dumps(labels)
updated_agent_thought.tool_labels_str = json.dumps(labels)


if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
try:
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
except Exception as e:
except Exception:
tool_invoke_meta = json.dumps(tool_invoke_meta)


agent_thought.tool_meta_str = tool_invoke_meta

db.session.commit()
db.session.close()
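The hunks above repeatedly swap `except Exception as e:` for a bare `except Exception:` around the same serialize-with-fallback idiom. A minimal standalone sketch of that idiom (`dump_with_fallback` is a hypothetical helper name, not part of the codebase):

```python
import json
from typing import Any


def dump_with_fallback(value: Any) -> str:
    """Serialize a dict to JSON, preferring readable non-ASCII output.

    Falls back to ASCII-escaped json.dumps when ensure_ascii=False fails;
    non-dict values pass through as plain strings, mirroring the diff's
    isinstance(..., dict) guard.
    """
    if not isinstance(value, dict):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        return json.dumps(value)


# Non-ASCII content stays readable instead of being \u-escaped.
assert dump_with_fallback({"a": "中文"}) == '{"a": "中文"}'
```

Dropping the unused `as e` binding is what the diff's `except Exception as e:` → `except Exception:` lines do; the exception object was never referenced.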

def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)

if not queried_variables:
return
updated_agent_thought.tool_meta_str = tool_invoke_meta


db_variables = queried_variables

db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()


tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)


for tool in tools:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if message.app_model_config:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
else:
file_extra_config = None

if not file_extra_config:
return UserPromptMessage(content=message.query)



+ 31
- 43
api/core/agent/cot_agent_runner.py

import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional


from core.agent.base_agent_runner import BaseAgentRunner
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message


class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] | None = None
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
_instruction: str = "" # FIXME this must be str for now
_query: str | None = None
_prompt_messages_tools: list[PromptMessageTool] = []
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]


def run(
self,
"""
Run Cot agent application
"""

app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
app_generate_entity.model_conf.stop.append("Observation")


app_config = self.app_config
assert app_config.agent


# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)


iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1


# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools


function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
callbacks=[],
)


if not isinstance(chunks, Generator):
raise ValueError("Expected streaming response from LLM")

# check llm result
if not chunks:
raise ValueError("failed to invoke llm")

usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
if scratchpad.agent_response is not None:
scratchpad.agent_response += json.dumps(chunk.model_dump())
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
if scratchpad.agent_response is not None:
scratchpad.agent_response += chunk
if scratchpad.thought is not None:
scratchpad.thought += chunk
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
if scratchpad.thought is not None:
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
if self._agent_scratchpad is not None:
self._agent_scratchpad.append(scratchpad)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)


# get llm usage
if "usage" in usage_dict:
answer=final_answer,
messages_ids=[],
)
if self.variables_pool is not None and self.db_variables_pool is not None:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
)


# publish files
for message_file_id, save_as in message_files:
if save_as is not None and self.variables_pool:
# FIXME the save_as type is confusing, it should be a string or not
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))

for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e:
except Exception:
continue


return instruction
return message


def _organize_historic_prompt_messages(
self, current_session_messages: Optional[list[PromptMessage]] = None
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
if not isinstance(message.content, str | None):
raise NotImplementedError("expected str type")
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
except:
pass
elif isinstance(message, ToolPromptMessage):
if not current_scratchpad:
continue
if isinstance(message.content, str):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")

+ 4
- 2
api/core/agent/cot_chat_agent_runner.py

""" """
Organize system prompt Organize system prompt
""" """
if not self.app_config.agent:
raise ValueError("Agent configuration is not set")
assert self.app_config.agent
assert self.app_config.agent.prompt


prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"

+ 16
- 5
api/core/agent/entities.py

from enum import Enum
from typing import Any, Literal, Optional, Union
from enum import StrEnum
from typing import Any, Optional, Union


from pydantic import BaseModel from pydantic import BaseModel


from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType



class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""


provider_type: Literal["builtin", "api", "workflow"]
provider_type: ToolProviderType
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
plugin_unique_identifier: str | None = None




class AgentPromptEntity(BaseModel):
Agent Entity.
"""


class Strategy(Enum):
class Strategy(StrEnum):
"""
Agent Strategy.
"""
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] | None = None
tools: Optional[list[AgentToolEntity]] = None
max_iteration: int = 5


class AgentInvokeMessage(ToolInvokeMessage):
"""
Agent Invoke Message.
"""

pass

+ 15
- 24
api/core/agent/fc_agent_runner.py

# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()


assert app_config.agent

iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1


# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""


# get tracing instance # get tracing instance
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager


def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage final_llm_usage_dict["usage"] = usage
else: else:


current_llm_usage = None


if self.stream_tool_call and isinstance(chunks, Generator):
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})


current_llm_usage = chunk.delta.usage


yield chunk
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
else:
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})


usage=result.usage,
),
)
else:
raise RuntimeError(f"invalid chunks type: {type(chunks)}")


assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id, save_as in message_files:
if save_as:
if self.variables_pool:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as
)

for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER


iteration_step += 1


if self.variables_pool and self.db_variables_pool:
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
return True
return False


def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
""" """
Extract tool calls from llm result chunk Extract tool calls from llm result chunk




return tool_calls return tool_calls


def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
""" """
Extract blocking tool calls from llm result Extract blocking tool calls from llm result




return tool_calls return tool_calls


def _init_system_message(
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
Initialize system message Initialize system message
""" """

+ 89
- 0
api/core/agent/plugin_entities.py

import enum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from core.entities.parameter_entities import CommonParameterType
from core.plugin.entities.parameters import (
PluginParameter,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolProviderIdentity,
)


class AgentStrategyProviderIdentity(ToolProviderIdentity):
"""
Inherits from ToolProviderIdentity, without any additional fields.
"""

pass


class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
"""
Keep all the types from PluginParameterType
"""

STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value

# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value

def as_normal_type(self):
return as_normal_type(self)

def cast_value(self, value: Any):
return cast_parameter_value(self, value)

type: AgentStrategyParameterType = Field(..., description="The type of the parameter")

def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)


class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")


class AgentStrategyIdentity(ToolIdentity):
"""
Inherits from ToolIdentity, without any additional fields.
"""

pass


class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None

# pydantic configs
model_config = ConfigDict(protected_namespaces=())

@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
return v or []


class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
strategies: list[AgentStrategyEntity] = Field(default_factory=list)

+ 42
- 0
api/core/agent/strategy/base.py

from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional

from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter


class BaseAgentStrategy(ABC):
"""
Agent Strategy
"""

def invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)

def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
Get the parameters for the agent strategy.
"""
return []

@abstractmethod
def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass
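`BaseAgentStrategy` above is a template method: the public `invoke` delegates to an abstract `_invoke` generator that each strategy implements. A stdlib-only sketch of that shape (`EchoStrategy` is a hypothetical subclass for illustration):

```python
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Any


class BaseStrategy(ABC):
    """Template method: the public entry point wraps the subclass hook."""

    def invoke(self, params: dict[str, Any]) -> Generator[str, None, None]:
        # Shared pre/post-processing would live here; the body is delegated.
        yield from self._invoke(params)

    @abstractmethod
    def _invoke(self, params: dict[str, Any]) -> Generator[str, None, None]: ...


class EchoStrategy(BaseStrategy):
    def _invoke(self, params: dict[str, Any]) -> Generator[str, None, None]:
        # Trivial hook: stream one message per parameter.
        for key, value in params.items():
            yield f"{key}={value}"


assert list(EchoStrategy().invoke({"q": "hi"})) == ["q=hi"]
```

Keeping `invoke` concrete lets the base class add cross-cutting behavior later without touching any plugin strategy.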

+ 59
- 0
api/core/agent/strategy/plugin.py

from collections.abc import Generator, Sequence
from typing import Any, Optional

from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


class PluginAgentStrategy(BaseAgentStrategy):
"""
Agent Strategy
"""

tenant_id: str
declaration: AgentStrategyEntity

def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
self.tenant_id = tenant_id
self.declaration = declaration

def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters

def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
"""
Initialize the parameters for the agent strategy.
"""
for parameter in self.declaration.parameters:
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params

def _invoke(
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
manager = PluginAgentManager()

initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)

yield from manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
agent_provider=self.declaration.identity.provider,
agent_strategy=self.declaration.identity.name,
agent_params=params,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
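`PluginAgentStrategy.initialize_parameters` back-fills every declared parameter before the call is forwarded to the plugin daemon, so the daemon never sees a missing key. A stdlib-only sketch of that normalization, with a hypothetical `Param` stand-in for `AgentStrategyParameter`:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Param:
    name: str
    default: Any = None

    def init_frontend_parameter(self, value: Optional[Any]) -> Any:
        # Hypothetical stand-in for AgentStrategyParameter's method:
        # fall back to the declared default when the caller omits a value.
        return self.default if value is None else value


def initialize_parameters(declared: list[Param], params: dict[str, Any]) -> dict[str, Any]:
    # Mirrors initialize_parameters: every declared parameter ends up
    # present in the dict, normalized through its own init hook.
    for p in declared:
        params[p.name] = p.init_frontend_parameter(params.get(p.name))
    return params


assert initialize_parameters([Param("depth", 3)], {}) == {"depth": 3}
assert initialize_parameters([Param("depth", 3)], {"depth": 5}) == {"depth": 5}
```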

+ 7
- 6
api/core/app/app_config/easy_ui_based_app/model_config/converter.py

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager


stop = completion_params["stop"]
del completion_params["stop"]


model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)

model_mode = mode_enum.value

model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value


if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

+ 12
- 2
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

from typing import Any


from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager




raise ValueError("model must be of object type")


# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)

if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")


# model.name

+ 23
- 13
api/core/app/apps/advanced_chat/app_generator.py

user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
streaming: Literal[False],
) -> Mapping[str, Any]: ...


@overload
def generate(
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Mapping,
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...


@overload
def generate(
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
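The three `generate` overloads above follow the standard `Literal`-based narrowing pattern: callers passing `streaming=True` get a generator type, `streaming=False` a mapping, and a plain `bool` the union. A stripped-down, runnable sketch (the single-argument signature is simplified from the real method):

```python
from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def generate(streaming: Literal[True]) -> Generator[str, None, None]: ...
@overload
def generate(streaming: Literal[False]) -> dict: ...
@overload
def generate(streaming: bool) -> Union[dict, Generator[str, None, None]]: ...


def generate(streaming: bool = True):
    # Single runtime implementation; the overloads only narrow the
    # static return type based on the literal value of `streaming`.
    if streaming:
        return (chunk for chunk in ("a", "b"))
    return {"answer": "ab"}


assert generate(streaming=False) == {"answer": "ab"}
assert list(generate(streaming=True)) == ["a", "b"]
```

The overload bodies are `...` by convention; only the final untyped definition executes.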


def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.


workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())


return self._generate(
workflow=workflow,
)


def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
) -> Mapping[str, Any] | Generator[str, None, None]:
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.


), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())


return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None, conversation: Optional[Conversation] = None,
stream: bool = True, stream: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.
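The reordered overloads above follow a standard typing pattern: `Literal[False]` pins the blocking return type, `Literal[True]` pins the streaming generator, and a `bool` overload covers callers that pass a runtime flag. A minimal standalone sketch of that pattern (names and payloads are illustrative, not the Dify signatures):

```python
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload


@overload
def generate(*, streaming: Literal[False]) -> Mapping[str, Any]: ...
@overload
def generate(*, streaming: Literal[True]) -> Generator[Union[Mapping[str, Any], str], None, None]: ...
@overload
def generate(*, streaming: bool) -> Union[Mapping[str, Any], Generator[Union[Mapping[str, Any], str], None, None]]: ...


def generate(*, streaming: bool = True):
    # single implementation; the overloads only narrow the static return type
    if streaming:
        def gen() -> Generator[Union[Mapping[str, Any], str], None, None]:
            yield {"event": "message", "answer": "chunk"}
            yield "ping"  # keep-alive frames are bare strings
        return gen()
    return {"answer": "complete"}
```

With this shape, `generate(streaming=False)` type-checks as a mapping and `generate(streaming=True)` as a generator, without a cast at the call site.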



api/core/app/apps/advanced_chat/app_generator_tts_publisher.py (+2, -2)

 class AppGeneratorTTSPublisher:
-    def __init__(self, tenant_id: str, voice: str):
+    def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
@@
         self.model_instance = self.model_manager.get_default_model_instance(
             tenant_id=self.tenant_id, model_type=ModelType.TTS
         )
-        self.voices = self.model_instance.get_tts_voices()
+        self.voices = self.model_instance.get_tts_voices(language=language)
         values = [voice.get("value") for voice in self.voices]
         self.voice = voice
         if not voice or voice not in values:
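The change threads an optional `language` through to voice lookup, with the existing fallback kept: if the configured voice is missing or not in the (now possibly language-filtered) catalogue, the first available voice is used. A small sketch of that selection logic, assuming voice entries carry `value` and `language` keys (the exact catalogue shape is an assumption here):

```python
from typing import Optional


def pick_voice(voices: list[dict], voice: Optional[str], language: Optional[str] = None) -> Optional[str]:
    # optionally narrow the catalogue by language first (mirrors get_tts_voices(language=...))
    if language:
        voices = [v for v in voices if v.get("language") == language]
    values = [v.get("value") for v in voices]
    # fall back to the first available voice when unset or unknown
    if not voice or voice not in values:
        return values[0] if values else None
    return voice
```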

api/core/app/apps/advanced_chat/app_runner.py (+1, -1)

             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
             )
         else:
             inputs = self.application_generate_entity.inputs

api/core/app/apps/advanced_chat/generate_response_converter.py (+4, -5)

-import json
 from collections.abc import Generator
 from typing import Any, cast
@@
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    ) -> Generator[dict | str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@
                 response_chunk.update(data)
             else:
                 response_chunk.update(sub_stream_response.to_dict())
-            yield json.dumps(response_chunk)
+            yield response_chunk

     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    ) -> Generator[dict | str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@
             else:
                 response_chunk.update(sub_stream_response.to_dict())

-            yield json.dumps(response_chunk)
+            yield response_chunk

api/core/app/apps/advanced_chat/generate_task_pipeline.py (+9, -2)

 )
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
+    QueueAgentLogEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
     QueueIterationCompletedEvent,
@@
             and features_dict["text_to_speech"].get("enabled")
             and features_dict["text_to_speech"].get("autoPlay") == "enabled"
         ):
-            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
+            tts_publisher = AppGeneratorTTSPublisher(
+                tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
+            )

         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
@@
                 else:
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Failed to listen audio message, task_id: {task_id}")
                 break
         if tts_publisher:
@@
                     session.commit()

                 yield self._message_end_to_stream_response()
+            elif isinstance(event, QueueAgentLogEvent):
+                yield self._workflow_cycle_manager._handle_agent_log(
+                    task_id=self._application_generate_entity.task_id, event=event
+                )
             else:
                 continue



api/core/app/apps/agent_chat/app_generator.py (+12, -6)

+import contextvars
 import logging
 import threading
 import uuid
@@
     user: Union[Account, EndUser],
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
-    streaming: Literal[True],
-) -> Generator[str, None, None]: ...
+    streaming: Literal[False],
+) -> Mapping[str, Any]: ...

 @overload
 def generate(
     user: Union[Account, EndUser],
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
-    streaming: Literal[False],
-) -> Mapping[str, Any]: ...
+    streaming: Literal[True],
+) -> Generator[Mapping | str, None, None]: ...

 @overload
 def generate(
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: bool,
-) -> Mapping[str, Any] | Generator[str, None, None]: ...
+) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...

 def generate(
     self,
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: bool = True,
-):
+) -> Union[Mapping, Generator[Mapping | str, None, None]]:
     """
     Generate App response.
@@
             target=self._generate_worker,
             kwargs={
                 "flask_app": current_app._get_current_object(),  # type: ignore
+                "context": contextvars.copy_context(),
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,
@@
     def _generate_worker(
         self,
         flask_app: Flask,
+        context: contextvars.Context,
         application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation_id: str,
@@
         :param message_id: message ID
         :return:
         """
+        for var, val in context.items():
+            var.set(val)

         with flask_app.app_context():
             try:
                 # get conversation and message
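The `contextvars.copy_context()` addition exists because a `threading.Thread` starts with an empty context: context variables set in the request thread (tenant id, plugin registries) would raise `LookupError` inside the worker unless explicitly restored. The diff's restore loop can be demonstrated in isolation:

```python
import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")


def worker(context: contextvars.Context, results: list) -> None:
    # restore the parent's context variables inside the new thread,
    # exactly as _generate_worker does in the diff
    for var, val in context.items():
        var.set(val)
    results.append(tenant_id.get())


tenant_id.set("tenant-123")
results: list = []
t = threading.Thread(target=worker, args=(contextvars.copy_context(), results))
t.start()
t.join()
```

Without the restore loop, `tenant_id.get()` in the worker would fail; with it, `results` ends up holding `"tenant-123"`.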

api/core/app/apps/agent_chat/app_runner.py (+16, -99)

 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationError
-from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought
-from models.tools import ToolConversationVariables
+from models.model import App, Conversation, Message

 logger = logging.getLogger(__name__)
@@
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
+            inputs=dict(inputs),
+            files=list(files),
             query=query,
         )
@@
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
+            inputs=dict(inputs),
+            files=list(files),
             query=query,
             memory=memory,
         )
@@
                 app_id=app_record.id,
                 tenant_id=app_config.tenant_id,
                 app_generate_entity=application_generate_entity,
-                inputs=inputs,
-                query=query,
+                inputs=dict(inputs),
+                query=query or "",
                 message_id=message.id,
             )
         except ModerationError as e:
@@
                 app_record=app_record,
                 model_config=application_generate_entity.model_conf,
                 prompt_template_entity=app_config.prompt_template,
-                inputs=inputs,
-                files=files,
-                query=query,
+                inputs=dict(inputs),
+                files=list(files),
+                query=query or "",
                 memory=memory,
             )

             return
@@
         agent_entity = app_config.agent
-        if not agent_entity:
-            raise ValueError("Agent entity not found")
-
-        # load tool variables
-        tool_conversation_variables = self._load_tool_variables(
-            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
-        )
-
-        # convert db variables to tool variables
-        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
+        assert agent_entity is not None

         # init model instance
         model_instance = ModelInstance(
@@
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
+            inputs=dict(inputs),
+            files=list(files),
+            query=query or "",
             memory=memory,
         )
@@
             user_id=application_generate_entity.user_id,
             memory=memory,
             prompt_messages=prompt_message,
-            variables_pool=tool_variables,
-            db_variables=tool_conversation_variables,
             model_instance=model_instance,
         )
@@
             stream=application_generate_entity.stream,
             agent=True,
         )
-
-    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
-        """
-        load tool variables from database
-        """
-        tool_variables: ToolConversationVariables | None = (
-            db.session.query(ToolConversationVariables)
-            .filter(
-                ToolConversationVariables.conversation_id == conversation_id,
-                ToolConversationVariables.tenant_id == tenant_id,
-            )
-            .first()
-        )
-
-        if tool_variables:
-            # save tool variables to session, so that we can update it later
-            db.session.add(tool_variables)
-        else:
-            # create new tool variables
-            tool_variables = ToolConversationVariables(
-                conversation_id=conversation_id,
-                user_id=user_id,
-                tenant_id=tenant_id,
-                variables_str="[]",
-            )
-            db.session.add(tool_variables)
-            db.session.commit()
-
-        return tool_variables
-
-    def _convert_db_variables_to_tool_variables(
-        self, db_variables: ToolConversationVariables
-    ) -> ToolRuntimeVariablePool:
-        """
-        convert db variables to tool variables
-        """
-        return ToolRuntimeVariablePool(
-            **{
-                "conversation_id": db_variables.conversation_id,
-                "user_id": db_variables.user_id,
-                "tenant_id": db_variables.tenant_id,
-                "pool": db_variables.variables,
-            }
-        )
-
-    def _get_usage_of_all_agent_thoughts(
-        self, model_config: ModelConfigWithCredentialsEntity, message: Message
-    ) -> LLMUsage:
-        """
-        Get usage of all agent thoughts
-        :param model_config: model config
-        :param message: message
-        :return:
-        """
-        agent_thoughts = (
-            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
-        )
-
-        all_message_tokens = 0
-        all_answer_tokens = 0
-        for agent_thought in agent_thoughts:
-            all_message_tokens += agent_thought.message_tokens
-            all_answer_tokens += agent_thought.answer_tokens
-
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        return model_type_instance._calc_response_usage(
-            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
-        )

api/core/app/apps/agent_chat/generate_response_converter.py (+9, -11)

-import json
 from collections.abc import Generator
 from typing import cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@
         return response

     @classmethod
-    def convert_stream_full_response(  # type: ignore[override]
-        cls,
-        stream_response: Generator[ChatbotAppStreamResponse, None, None],
-    ) -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@
                 response_chunk.update(data)
             else:
                 response_chunk.update(sub_stream_response.to_dict())
-            yield json.dumps(response_chunk)
+            yield response_chunk

     @classmethod
-    def convert_stream_simple_response(  # type: ignore[override]
-        cls,
-        stream_response: Generator[ChatbotAppStreamResponse, None, None],
-    ) -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@
             else:
                 response_chunk.update(sub_stream_response.to_dict())

-            yield json.dumps(response_chunk)
+            yield response_chunk

api/core/app/apps/base_app_generate_response_converter.py (+8, -18)

     @classmethod
     def convert(
-        cls,
-        response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
-        invoke_from: InvokeFrom,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:
@@
-                def _generate_full_response() -> Generator[str, Any, None]:
-                    for chunk in cls.convert_stream_full_response(response):
-                        if chunk == "ping":
-                            yield f"event: {chunk}\n\n"
-                        else:
-                            yield f"data: {chunk}\n\n"
+                def _generate_full_response() -> Generator[dict | str, Any, None]:
+                    yield from cls.convert_stream_full_response(response)

                 return _generate_full_response()
         else:
@@
                 return cls.convert_blocking_simple_response(response)
             else:
-                def _generate_simple_response() -> Generator[str, Any, None]:
-                    for chunk in cls.convert_stream_simple_response(response):
-                        if chunk == "ping":
-                            yield f"event: {chunk}\n\n"
-                        else:
-                            yield f"data: {chunk}\n\n"
+                def _generate_simple_response() -> Generator[dict | str, Any, None]:
+                    yield from cls.convert_stream_simple_response(response)

                 return _generate_simple_response()

     @abstractmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError

     @classmethod
     @abstractmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError

     @classmethod

api/core/app/apps/base_app_generator.py (+21, -2)

-from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional
+import json
+from collections.abc import Generator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union

 from core.app.app_config.entities import VariableEntityType
 from core.file import File, FileUploadConfig
@@
         if isinstance(value, str):
             return value.replace("\x00", "")
         return value

+    @classmethod
+    def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
+        """
+        Convert messages into event stream
+        """
+        if isinstance(generator, dict):
+            return generator
+        else:
+
+            def gen():
+                for message in generator:
+                    if isinstance(message, (Mapping, dict)):
+                        yield f"data: {json.dumps(message)}\n\n"
+                    else:
+                        yield f"event: {message}\n\n"
+
+            return gen()
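This new helper is where the serialization removed from the per-app converters lands: converters now yield plain dicts (or bare strings such as `"ping"`), and a single shared function turns them into server-sent-event frames at the edge. A self-contained, runnable version of the same logic:

```python
import json
from collections.abc import Generator, Mapping
from typing import Union


def convert_to_event_stream(generator: Union[Mapping, Generator[Union[Mapping, str], None, None]]):
    # blocking (non-streaming) responses pass through unchanged
    if isinstance(generator, dict):
        return generator

    def gen():
        for message in generator:
            if isinstance(message, (Mapping, dict)):
                # dict chunks are serialized once, here, as SSE data frames
                yield f"data: {json.dumps(message)}\n\n"
            else:
                # bare strings (e.g. "ping" keep-alives) become SSE event frames
                yield f"event: {message}\n\n"

    return gen()
```

Centralizing `json.dumps` here means downstream consumers inside the process can work with structured chunks, and only HTTP responses pay the serialization cost.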

api/core/app/apps/base_app_queue_manager.py (+2, -2)

 import time
 from abc import abstractmethod
 from enum import Enum
-from typing import Any
+from typing import Any, Optional

 from sqlalchemy.orm import DeclarativeMeta
@@
         Set task stop flag
         :return:
         """
-        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
+        result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
         if result is None:
             return



api/core/app/apps/chat/app_generator.py (+3, -3)

     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: Literal[True],
-) -> Generator[str, None, None]: ...
+) -> Generator[Mapping | str, None, None]: ...

 @overload
 def generate(
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: bool,
-) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
+) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...

 def generate(
     self,
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: bool = True,
-):
+) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
     """
     Generate App response.



api/core/app/apps/chat/generate_response_converter.py (+7, -9)

-import json
 from collections.abc import Generator
 from typing import cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@
     @classmethod
     def convert_stream_full_response(
-        cls,
-        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
-    ) -> Generator[str, None, None]:
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@
                 response_chunk.update(data)
             else:
                 response_chunk.update(sub_stream_response.to_dict())
-            yield json.dumps(response_chunk)
+            yield response_chunk

     @classmethod
     def convert_stream_simple_response(
-        cls,
-        stream_response: Generator[ChatbotAppStreamResponse, None, None],  # type: ignore[override]
-    ) -> Generator[str, None, None]:
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@
             else:
                 response_chunk.update(sub_stream_response.to_dict())

-            yield json.dumps(response_chunk)
+            yield response_chunk

api/core/app/apps/completion/app_generator.py (+5, -5)

     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: Literal[True],
-) -> Generator[str, None, None]: ...
+) -> Generator[str | Mapping[str, Any], None, None]: ...

 @overload
 def generate(
     user: Union[Account, EndUser],
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
-    streaming: bool,
-) -> Mapping[str, Any] | Generator[str, None, None]: ...
+    streaming: bool = False,
+) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...

 def generate(
     self,
     args: Mapping[str, Any],
     invoke_from: InvokeFrom,
     streaming: bool = True,
-):
+) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
     """
     Generate App response.
@@
     user: Union[Account, EndUser],
     invoke_from: InvokeFrom,
     stream: bool = True,
-) -> Union[Mapping[str, Any], Generator[str, None, None]]:
+) -> Union[Mapping, Generator[Mapping | str, None, None]]:
     """
     Generate App response.



api/core/app/apps/completion/generate_response_converter.py (+7, -9)

-import json
 from collections.abc import Generator
 from typing import cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     CompletionAppBlockingResponse,
     CompletionAppStreamResponse,
     ErrorStreamResponse,
@@
     @classmethod
     def convert_stream_full_response(
-        cls,
-        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
-    ) -> Generator[str, None, None]:
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@
                 response_chunk.update(data)
             else:
                 response_chunk.update(sub_stream_response.to_dict())
-            yield json.dumps(response_chunk)
+            yield response_chunk

     @classmethod
     def convert_stream_simple_response(
-        cls,
-        stream_response: Generator[CompletionAppStreamResponse, None, None],  # type: ignore[override]
-    ) -> Generator[str, None, None]:
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@
             else:
                 response_chunk.update(sub_stream_response.to_dict())

-            yield json.dumps(response_chunk)
+            yield response_chunk

+ 33
- 17
api/core/app/apps/workflow/app_generator.py View File

*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...


@overload @overload
def generate( def generate(
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[False], streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ... ) -> Mapping[str, Any]: ...


@overload @overload
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...


def generate( def generate(
self, self,
*, *,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
):
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []


# parse files # parse files
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )

contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())


return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.

:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager # init queue manager
queue_manager = WorkflowAppQueueManager( queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
app_model: App, app_model: App,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user: Account,
user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
streaming: bool = True, streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.


workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())


return self._generate(
app_model=app_model,
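Before delegating to `_generate`, the generator now seeds two request-scoped context variables: a `plugin_tool_providers` registry and a lock to guard it. A minimal sketch of that pattern, using hypothetical helper names (`begin_request`, `register_provider`) rather than Dify's actual internals:

```python
import threading
from contextvars import ContextVar

# A per-context registry dict plus a lock, mirroring the pattern in the diff,
# so concurrent workers inside one request can populate the cache safely.
plugin_tool_providers: ContextVar[dict] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[threading.Lock] = ContextVar("plugin_tool_providers_lock")


def begin_request() -> None:
    # Fresh registry and lock per generation request, as the diff does
    # right before calling self._generate(...).
    plugin_tool_providers.set({})
    plugin_tool_providers_lock.set(threading.Lock())


def register_provider(name: str, provider: object) -> None:
    # Cache a provider once; the lock guards the check-then-set.
    with plugin_tool_providers_lock.get():
        providers = plugin_tool_providers.get()
        if name not in providers:
            providers[name] = provider


begin_request()
register_provider("langgenius/example", object())
```

Because `ContextVar` values are isolated per execution context, each request sees its own registry without passing it through every call signature.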

api/core/app/apps/workflow/generate_response_converter.py (+7 -9)

import json
from collections.abc import Generator
from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,


@classmethod
def convert_stream_full_response(
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
:param stream_response: stream response :param stream_response: stream response
response_chunk.update(data) response_chunk.update(data)
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk


@classmethod
def convert_stream_simple_response(
cls,
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.
:param stream_response: stream response :param stream_response: stream response
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield json.dumps(response_chunk)
yield response_chunk

api/core/app/apps/workflow/generate_task_pipeline.py (+8 -1)

WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)


for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue



api/core/app/apps/workflow_app_runner.py (+16 -0)

from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
)
)
elif isinstance(event, NodeRunSucceededEvent):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
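The added `AgentLogEvent` branch follows the runner's existing pattern: an `isinstance` chain that translates each graph-engine event into its queue counterpart by copying fields one-to-one. A simplified sketch with stand-in classes (not the real Dify entities, which carry more fields):

```python
from dataclasses import dataclass


@dataclass
class AgentLogEvent:
    # Stand-in for the graph-engine-side event.
    id: str
    label: str
    status: str


@dataclass
class QueueAgentLogEvent:
    # Stand-in for the queue-side event published to the task pipeline.
    id: str
    label: str
    status: str


def translate(event: object) -> QueueAgentLogEvent:
    # Mirror of the runner's isinstance dispatch: one branch per engine
    # event type, each constructing the matching queue event field by field.
    if isinstance(event, AgentLogEvent):
        return QueueAgentLogEvent(id=event.id, label=event.label, status=event.status)
    raise TypeError(f"unhandled event: {type(event).__name__}")
```

The explicit field copy keeps the queue layer decoupled from engine internals: the two event types can evolve independently as long as the branch is updated.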

api/core/app/entities/app_invoke_entities.py (+1 -1)

""" """


node_id: str node_id: str
inputs: dict
inputs: Mapping


single_iteration_run: Optional[SingleIterationRunEntity] = None single_iteration_run: Optional[SingleIterationRunEntity] = None



api/core/app/entities/queue_entities.py (+19 -1)

from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
STOP = "stop"
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None




class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None




class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""

event: QueueEvent = QueueEvent.AGENT_LOG
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None


class QueueNodeRetryEvent(QueueNodeStartedEvent):
"""QueueNodeRetryEvent entity"""



api/core/app/entities/task_entities.py (+26 -0)



from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from models.workflow import WorkflowNodeExecutionStatus


ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"




class StreamResponse(BaseModel):
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None

event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str

workflow_run_id: str
data: Data


class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
"""

class Data(BaseModel):
"""
Data entity
"""

node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None

event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

api/core/app/features/hosting_moderation/hosting_moderation.py (+3 -1)

if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"


moderation_result = moderation.check_moderation(model_config, text)
moderation_result = moderation.check_moderation(
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
)


return moderation_result

api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py (+3 -1)

and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher, task_id)

api/core/app/task_pipeline/workflow_cycle_manage.py (+26 -3)



from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
execution_metadata_dict = dict(event.execution_metadata or {})
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()


parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)


raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution)

def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
),
)

api/core/callback_handler/agent_tool_callback_handler.py (+2 -2)

from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union

from pydantic import BaseModel
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage] | str,
tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,

api/core/callback_handler/workflow_tool_callback_handler.py (+22 -1)

from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from collections.abc import Generator, Iterable, Mapping
from typing import Any, Optional

from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage




class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""

def on_tool_execution(
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Iterable[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[ToolInvokeMessage, None, None]:
for tool_output in tool_outputs:
print_text("\n[on_tool_execution]\n", color=self.color)
print_text("Tool: " + tool_name + "\n", color=self.color)
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
print_text("\n")
yield tool_output
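The new `on_tool_execution` hook is a pass-through generator: it logs each tool output as it streams by, then re-yields it unchanged, so the handler can sit inside a streaming pipeline without buffering the whole result. A generic sketch of that tee pattern, with a stand-in list-based logger instead of Dify's `print_text`:

```python
from collections.abc import Generator, Iterable
from typing import TypeVar

T = TypeVar("T")


def tee_log(name: str, outputs: Iterable[T], log: list[str]) -> Generator[T, None, None]:
    # Observe each item as it streams past, then forward it unchanged.
    # Because this is a generator, items flow through lazily, one at a time.
    for item in outputs:
        log.append(f"[{name}] {item!r}")
        yield item


log: list[str] = []
forwarded = list(tee_log("search_tool", iter(["chunk1", "chunk2"]), log))
```

Consumers must iterate the returned generator for the logging side effect to happen at all, which matches how the callback is wired into the tool-output stream.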

api/core/entities/__init__.py (+1 -0)

DEFAULT_PLUGIN_ID = "langgenius"

api/core/entities/parameter_entities.py (+42 -0)

from enum import StrEnum


class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"

# TOOL_SELECTOR = "tool-selector"


class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"


class ModelSelectorScope(StrEnum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"


class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"

api/core/entities/provider_configuration.py (+120 -104)

import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict

from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.provider import (
LoadBalancingModelConfig,
continue

restrict_models = quota_configuration.restrict_models
if self.system_configuration.credentials is None:
return None
copy_credentials = self.system_configuration.credentials.copy()

copy_credentials = (
self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
)
if restrict_models:
for restrict_model in restrict_models:
if (
if current_quota_configuration is None:
return None


if not current_quota_configuration:
return SystemConfigurationStatus.UNSUPPORTED

return (
SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0


def get_custom_credentials(self, obfuscated: bool = False):
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
""" """
Get custom credentials. Get custom credentials.


else [], else [],
) )


def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
""" """
Validate custom credentials. Validate custom credentials.
:param credentials: provider credentials :param credentials: provider credentials
if value == HIDDEN_VALUE and key in original_credentials: if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])


model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_record = Provider()
provider_record.tenant_id = self.tenant_id
provider_record.provider_name = self.provider.provider
provider_record.provider_type = ProviderType.CUSTOM.value
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
db.session.add(provider_record)
db.session.commit()




def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[Optional[ProviderModel], dict]:
) -> tuple[ProviderModel | None, dict]:
"""
Validate custom model credentials.

if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])


model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_model_record = ProviderModel()
provider_model_record.tenant_id = self.tenant_id
provider_model_record.provider_name = self.provider.provider
provider_model_record.model_name = model
provider_model_record.model_type = model_type.to_origin_model_type()
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
db.session.add(provider_model_record)
db.session.commit()

model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = True
db.session.add(model_setting)
db.session.commit()

model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = False
db.session.add(model_setting)
db.session.commit()

model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = True
db.session.add(model_setting)
db.session.commit()

model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = False
db.session.add(model_setting)
db.session.commit()

return model_setting


def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)

def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)

# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)

def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
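Throughout this file the module-level `model_provider_factory` singleton is replaced by a `ModelProviderFactory(self.tenant_id)` constructed at each call site, presumably so provider and schema lookups can account for tenant-scoped plugins. A minimal sketch of that per-tenant factory shape, with purely illustrative schema data and method bodies:

```python
class ModelProviderFactory:
    # Sketch only: the real factory resolves schemas from installed plugins;
    # here a hard-coded dict stands in for that lookup.
    def __init__(self, tenant_id: str) -> None:
        self.tenant_id = tenant_id
        self._schemas = {
            "openai": {"supported_model_types": ["llm", "text-embedding"]},
        }

    def get_provider_schema(self, provider: str) -> dict:
        # Per-instance lookup means two tenants can resolve the same
        # provider name to different schemas.
        return self._schemas[provider]


factory = ModelProviderFactory("tenant-1")
schema = factory.get_provider_schema("openai")
```

Constructing the factory per call trades a little overhead for correctness: no tenant's plugin state can leak into another's lookups through shared module state.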


def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
)
preferred_model_provider = TenantPreferredModelProvider()
preferred_model_provider.tenant_id = self.tenant_id
preferred_model_provider.provider_name = self.provider.provider
preferred_model_provider.preferred_provider_type = provider_type.value
db.session.add(preferred_model_provider)

db.session.commit()
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)


model_types = []
model_types: list[ModelType] = []
if model_type:
model_types.append(model_type)
else:
model_types = list(provider_instance.get_provider_schema().supported_model_types)
model_types = list(provider_schema.supported_model_types)


# Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)

if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)

if only_active:


def _get_system_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.

:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
provider_models = []
for model_type in model_types:
for m in provider_instance.models(model_type):
for m in provider_schema.models:
if m.model_type != model_type:
continue

status = ModelStatus.ACTIVE status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
if m.model in model_setting_map:
model_setting = model_setting_map[m.model_type][m.model] model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False: if model_setting.enabled is False:
status = ModelStatus.DISABLED status = ModelStatus.DISABLED


if self.provider.provider not in original_provider_configurate_methods: if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = [] original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
for configurate_method in provider_schema.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method) original_provider_configurate_methods[self.provider.provider].append(configurate_method)


should_use_custom_model = False should_use_custom_model = False
]: ]:
# only customizable model # only customizable model
for restrict_model in restrict_models: for restrict_model in restrict_models:
if self.system_configuration.credentials is not None:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name

try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
copy_credentials = (
self.system_configuration.credentials.copy()
if self.system_configuration.credentials
else {}
)
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name

try:
custom_model_schema = self.get_model_schema(
model_type=restrict_model.model_type,
model=restrict_model.model,
credentials=copy_credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")


if not custom_model_schema:
continue


def _get_custom_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.

:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
if model_type not in self.provider.supported_model_types:
continue


models = provider_instance.models(model_type)
for m in models:
for m in provider_schema.models:
if m.model_type != model_type:
continue

status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
continue


try:
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
custom_model_schema = self.get_model_schema(
model_type=model_configuration.model_type,
model=model_configuration.model,
credentials=model_configuration.credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
return list(self.values())


def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"

return self.configurations[key]


def __setitem__(self, key, value):
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())


def get(self, key, default=None):
return self.configurations.get(key, default)
def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"

return self.configurations.get(key, default) # type: ignore
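The `__getitem__` and `get` overrides above fall back to a fully-qualified plugin key (`plugin_id/provider/provider`) when handed a bare provider name. A minimal standalone sketch of that normalization (the `DEFAULT_PLUGIN_ID` value here is a stand-in, not necessarily Dify's):

```python
DEFAULT_PLUGIN_ID = "langgenius"  # stand-in for core.entities.DEFAULT_PLUGIN_ID


class ProviderConfigurations(dict):
    def _normalize(self, key: str) -> str:
        # A bare name like "openai" becomes "langgenius/openai/openai";
        # already-qualified keys pass through untouched.
        if "/" not in key:
            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
        return key

    def __getitem__(self, key):
        return super().__getitem__(self._normalize(key))

    def get(self, key, default=None):
        return super().get(self._normalize(key), default)


configs = ProviderConfigurations()
configs[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = {"enabled": True}
assert configs["openai"] == configs[f"{DEFAULT_PLUGIN_ID}/openai/openai"]
```

Because only lookups are normalized, callers written against the old bare names keep working while the store itself holds fully-qualified plugin keys.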




class ProviderModelBundle(BaseModel):
"""

configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel

# pydantic configs

api/core/entities/provider_entities.py (+79 -3)

from enum import Enum
from typing import Optional
from typing import Optional, Union


from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field


from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
from core.tools.entities.common_entities import I18nObject


class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""

FREE = "free"
"""third-party free quota"""

TRIAL = "trial"
"""hosted trial quota"""

@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")




class QuotaUnit(Enum):


# pydantic configs
model_config = ConfigDict(protected_namespaces=())


class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""

class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value

@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
"""
Get value of given mode.

:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")

type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")


class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""

class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")

scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None

def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)

api/core/file/helpers.py (+35 -0)

return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"




def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
url = f"{dify_config.FILES_URL}/files/upload/for-plugin"

if user_id is None:
user_id = "DEFAULT-USER"

timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = dify_config.SECRET_KEY.encode()
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()

return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"


def verify_plugin_file_signature(
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
) -> bool:
if user_id is None:
user_id = "DEFAULT-USER"

data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

# verify signature
if sign != recalculated_encoded_sign:
return False

current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
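`get_signed_file_url_for_plugin` and `verify_plugin_file_signature` share one HMAC-SHA256 scheme: sign a pipe-joined message with the server secret, recompute and compare on verification, then enforce an expiry window. A self-contained sketch of the round trip (the secret and timeout values are stand-ins for `dify_config` settings; `hmac.compare_digest` is used here for a constant-time comparison, where the code above uses a plain `!=`):

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"example-secret"  # stand-in for dify_config.SECRET_KEY
FILES_ACCESS_TIMEOUT = 300      # stand-in expiry window, in seconds


def _sign(msg: str) -> str:
    digest = hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
    return base64.urlsafe_b64encode(digest).decode()


def sign_upload(filename: str, mimetype: str, tenant_id: str, user_id: str):
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
    return timestamp, nonce, _sign(msg)


def verify_upload(filename, mimetype, tenant_id, user_id, timestamp, nonce, sign) -> bool:
    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
    if not hmac.compare_digest(sign, _sign(msg)):
        return False
    # signatures also expire after the access window
    return int(time.time()) - int(timestamp) <= FILES_ACCESS_TIMEOUT


ts, nonce, sig = sign_upload("a.png", "image/png", "t1", "u1")
assert verify_upload("a.png", "image/png", "t1", "u1", ts, nonce, sig)
assert not verify_upload("b.png", "image/png", "t1", "u1", ts, nonce, sig)
```

Any field change (filename, tenant, user, timestamp, or nonce) alters the signed message, so a tampered URL fails verification.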


def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()

api/core/file/models.py (+12 -1)

from collections.abc import Mapping, Sequence
from typing import Optional
from typing import Any, Optional


from pydantic import BaseModel, Field, model_validator


tool_file_id=self.related_id, extension=self.extension
)


def to_plugin_parameter(self) -> dict[str, Any]:
return {
"dify_model_identity": FILE_MODEL_IDENTITY,
"mime_type": self.mime_type,
"filename": self.filename,
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(),
}

@model_validator(mode="after")
def validate_after(self):
match self.transfer_method:

api/core/file/upload_file_parser.py (+69 -0)

import base64
import logging
import time
from typing import Optional

from configs import dify_config
from core.helper.url_signer import UrlSigner
from extensions.ext_storage import storage

IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])


class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None

if upload_file.extension not in IMAGE_EXTENSIONS:
return None

if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.exception(f"File not found: {upload_file.key}")
return None

encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"

@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file

:param upload_file_id: upload file id
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"

return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")

@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature

:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
result = UrlSigner.verify(
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
)

# verify signature
if not result:
return False

current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

api/core/helper/download.py (+17 -0)

from core.helper import ssrf_proxy


def download_with_size_limit(url, max_download_size: int, **kwargs):
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
if response.status_code == 404:
raise ValueError("file not found")

total_size = 0
chunks = []
for chunk in response.iter_bytes():
total_size += len(chunk)
if total_size > max_download_size:
raise ValueError("Max file size reached")
chunks.append(chunk)
content = b"".join(chunks)
return content
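`download_with_size_limit` enforces the cap while streaming, so an oversized response fails as soon as the running total passes the limit rather than after the whole body is buffered. The accumulation logic can be exercised without the network; the chunk iterable below stands in for `response.iter_bytes()`:

```python
def collect_with_limit(chunks, max_size: int) -> bytes:
    # Accumulate streamed chunks, failing as soon as the running total
    # exceeds the limit instead of after the full download.
    total = 0
    parts = []
    for chunk in chunks:
        total += len(chunk)
        if total > max_size:
            raise ValueError("Max file size reached")
        parts.append(chunk)
    return b"".join(parts)


assert collect_with_limit([b"ab", b"cd"], max_size=4) == b"abcd"
try:
    collect_with_limit([b"ab", b"cd", b"e"], max_size=4)
except ValueError as e:
    assert str(e) == "Max file size reached"
```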

api/core/helper/marketplace.py (+35 -0)

from collections.abc import Sequence

import requests
from yarl import URL

from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration


def get_plugin_pkg_url(plugin_unique_identifier: str):
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
unique_identifier=plugin_unique_identifier
)


def download_plugin_pkg(plugin_unique_identifier: str):
url = str(get_plugin_pkg_url(plugin_unique_identifier))
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)


def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []

url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]


def record_install_plugin_event(plugin_unique_identifier: str):
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

api/core/helper/moderation.py (+22 -10)

import logging
import random
from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration
from models.provider import ProviderType

logger = logging.getLogger(__name__)


def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
if (
moderation_config
and moderation_config.enabled is True
and "openai" in hosting_configuration.provider_map
and hosting_configuration.provider_map["openai"].enabled is True
and openai_provider_name in hosting_configuration.provider_map
and hosting_configuration.provider_map[openai_provider_name].enabled is True
):
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
assert hosting_openai_config is not None
hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]

if hosting_openai_config.credentials is None:
return False

# 2000 text per chunk
length = 2000
text_chunk = random.choice(text_chunks)

try:
model_type_instance = OpenAIModerationModel()
# FIXME, for type hint using assert or raise ValueError is better here?
model_provider_factory = ModelProviderFactory(tenant_id)

# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke(
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
)

if moderation_result is True:
return True
except Exception as ex:
except Exception:
logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
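The chunking step in `check_moderation` moderates a single randomly chosen 2000-character window instead of the whole text, a cost-saving heuristic. A sketch of that sampling (the `text_chunks` construction is inferred from the surrounding code, and non-empty input is assumed):

```python
import random


def sample_chunk(text: str, length: int = 2000) -> str:
    # Split the text into fixed-size windows and moderate only one of them.
    text_chunks = [text[i : i + length] for i in range(0, len(text), length)]
    return random.choice(text_chunks)


chunk = sample_chunk("a" * 5000)
assert len(chunk) in (1000, 2000)  # windows of 2000, 2000, and 1000 chars
```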



api/core/helper/ssrf_proxy.py (+0 -1)

)

retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:

api/core/helper/tool_provider_cache.py (+1 -0)



class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"


class ToolProviderCredentialsCache:

api/core/helper/url_signer.py (+52 -0)

import base64
import hashlib
import hmac
import os
import time

from pydantic import BaseModel, Field

from configs import dify_config


class SignedUrlParams(BaseModel):
sign_key: str = Field(..., description="The sign key")
timestamp: str = Field(..., description="Timestamp")
nonce: str = Field(..., description="Nonce")
sign: str = Field(..., description="Signature")


class UrlSigner:
@classmethod
def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
signed_url_params = cls.get_signed_url_params(sign_key, prefix)
return (
f"{url}?timestamp={signed_url_params.timestamp}"
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
)

@classmethod
def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
sign = cls._sign(sign_key, timestamp, nonce, prefix)

return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)

@classmethod
def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)

return sign == recalculated_sign

@classmethod
def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
if not dify_config.SECRET_KEY:
raise Exception("SECRET_KEY is not set")

data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()

return encoded_sign

api/core/hosting_configuration.py (+16 -9)

from pydantic import BaseModel

from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType


class HostingQuota(BaseModel):
if dify_config.EDITION != "CLOUD":
return

self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
self.provider_map["spark"] = self.init_spark()
self.provider_map["zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()

self.moderation_config = self.init_moderation_config()

@staticmethod
def init_moderation_config() -> HostedModerationConfig:
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",")
hosted_providers = []
for provider in providers:
if "/" not in provider:
provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}"
hosted_providers.append(provider)

return HostedModerationConfig(enabled=True, providers=hosted_providers)

return HostedModerationConfig(enabled=False)



api/core/indexing_runner.py (+3 -5)

FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

tokens = 0
if embedding_model_instance:
tokens += sum(
embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
for document in chunk_documents
)
page_content_list = [document.page_content for document in chunk_documents]
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))

# load index
index_processor.load(dataset, chunk_documents, with_keywords=False)
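The change above replaces one token-count call per document with a single batched call whose result is summed once, matching `get_text_embedding_num_tokens`'s new `list[int]` return type. A toy sketch of the batched shape (the whitespace tokenizer is purely illustrative, not the real embedding model):

```python
def get_num_tokens_batch(texts: list[str]) -> list[int]:
    # Stand-in for embedding_model_instance.get_text_embedding_num_tokens,
    # which now takes the whole batch and returns one count per text.
    return [len(t.split()) for t in texts]


page_content_list = ["hello world", "a b c"]
tokens = sum(get_num_tokens_batch(page_content_list))
assert tokens == 5
```

Batching collapses N round trips to the tokenizer into one, which matters when each call crosses the plugin boundary.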

api/core/llm_generator/llm_generator.py (+8 -8)

response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = cast(str, response.message.content)
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
except Exception:
logging.exception("Failed to generate suggested questions after answer")
questions = []

response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)

prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)



api/core/model_manager.py (+40 -4)

import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast
from typing import IO, Any, Literal, Optional, Union, cast, overload

from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType

return None


@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[True] = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator: ...

@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[False] = False,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResult: ...

@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ...
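The three `@overload` stubs above let type checkers narrow `invoke_llm`'s return type from the `stream` flag: `Literal[True]` yields a `Generator`, `Literal[False]` yields an `LLMResult`, and a plain `bool` falls back to the union. A reduced sketch of the pattern with illustrative names (not Dify's API):

```python
from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def invoke(prompt: str, stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def invoke(prompt: str, stream: Literal[False] = False) -> str: ...
@overload
def invoke(prompt: str, stream: bool = True) -> Union[str, Generator[str, None, None]]: ...


def invoke(prompt: str, stream: bool = True):
    # Single runtime implementation; the overloads above are erased at
    # runtime and only guide the type checker.
    if stream:
        return (token for token in prompt.split())
    return prompt


chunks = list(invoke("hello world", stream=True))  # checker infers Generator
text = invoke("hello world", stream=False)         # checker infers str
assert chunks == ["hello", "world"]
assert text == "hello world"
```

This removes the `cast(LLMResult, ...)` boilerplate that non-streaming callers otherwise need at each call site.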

def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
),
)


def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
"""
Get number of tokens for text embedding


self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
int,
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,

return ModelInstance(provider_model_bundle, model)

def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""
Return first provider and the first model in the provider
:param tenant_id: tenant id

api/core/model_runtime/entities/model_entities.py (+0 -0)


Some files were not shown because too many files changed in this diff
