Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>tags/1.0.0
| @@ -1,11 +1,12 @@ | |||
| #!/bin/bash | |||
| cd web && npm install | |||
| npm add -g pnpm@9.12.2 | |||
| cd web && pnpm install | |||
| pipx install poetry | |||
| echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc | |||
| echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc | |||
| echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc | |||
| echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc | |||
| echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc | |||
| echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc | |||
| @@ -50,15 +50,9 @@ jobs: | |||
| - name: Run Unit tests | |||
| run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh | |||
| - name: Run ModelRuntime | |||
| run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh | |||
| - name: Run dify config tests | |||
| run: poetry run -P api python dev/pytest/pytest_config_tests.py | |||
| - name: Run Tool | |||
| run: poetry run -P api bash dev/pytest/pytest_tools.sh | |||
| - name: Run mypy | |||
| run: | | |||
| poetry run -C api python -m mypy --install-types --non-interactive . | |||
| @@ -4,6 +4,7 @@ on: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| - plugins/beta | |||
| paths: | |||
| - api/migrations/** | |||
| - .github/workflows/db-migration-test.yml | |||
| @@ -72,17 +72,23 @@ jobs: | |||
| with: | |||
| files: web/** | |||
| - name: Install pnpm | |||
| uses: pnpm/action-setup@v4 | |||
| with: | |||
| version: 10 | |||
| run_install: false | |||
| - name: Setup NodeJS | |||
| uses: actions/setup-node@v4 | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| with: | |||
| node-version: 20 | |||
| cache: yarn | |||
| cache: pnpm | |||
| cache-dependency-path: ./web/package.json | |||
| - name: Web dependencies | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: yarn install --frozen-lockfile | |||
| run: pnpm install --frozen-lockfile | |||
| - name: Web style check | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| @@ -35,10 +35,10 @@ jobs: | |||
| with: | |||
| node-version: ${{ matrix.node-version }} | |||
| cache: '' | |||
| cache-dependency-path: 'yarn.lock' | |||
| cache-dependency-path: 'pnpm-lock.yaml' | |||
| - name: Install Dependencies | |||
| run: yarn install | |||
| run: pnpm install | |||
| - name: Test | |||
| run: yarn test | |||
| run: pnpm test | |||
| @@ -39,11 +39,11 @@ jobs: | |||
| - name: Install dependencies | |||
| if: env.FILES_CHANGED == 'true' | |||
| run: yarn install --frozen-lockfile | |||
| run: pnpm install --frozen-lockfile | |||
| - name: Run npm script | |||
| if: env.FILES_CHANGED == 'true' | |||
| run: npm run auto-gen-i18n | |||
| run: pnpm run auto-gen-i18n | |||
| - name: Create Pull Request | |||
| if: env.FILES_CHANGED == 'true' | |||
| @@ -37,13 +37,13 @@ jobs: | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| with: | |||
| node-version: 20 | |||
| cache: yarn | |||
| cache: pnpm | |||
| cache-dependency-path: ./web/package.json | |||
| - name: Install dependencies | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: yarn install --frozen-lockfile | |||
| run: pnpm install --frozen-lockfile | |||
| - name: Run tests | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| run: yarn test | |||
| run: pnpm test | |||
| @@ -176,6 +176,7 @@ docker/volumes/pgvector/data/* | |||
| docker/volumes/pgvecto_rs/data/* | |||
| docker/volumes/couchbase/* | |||
| docker/volumes/oceanbase/* | |||
| docker/volumes/plugin_daemon/* | |||
| !docker/volumes/oceanbase/init.d | |||
| docker/nginx/conf.d/default.conf | |||
| @@ -194,3 +195,9 @@ api/.vscode | |||
| .idea/ | |||
| .vscode | |||
| # pnpm | |||
| /.pnpm-store | |||
| # plugin migrate | |||
| plugins.jsonl | |||
| @@ -1,7 +1,10 @@ | |||
| .env | |||
| *.env.* | |||
| storage/generate_files/* | |||
| storage/privkeys/* | |||
| storage/tools/* | |||
| storage/upload_files/* | |||
| # Logs | |||
| logs | |||
| @@ -9,6 +12,8 @@ logs | |||
| # jetbrains | |||
| .idea | |||
| .mypy_cache | |||
| .ruff_cache | |||
| # venv | |||
| .venv | |||
| @@ -409,7 +409,6 @@ MAX_VARIABLE_SIZE=204800 | |||
| APP_MAX_EXECUTION_TIME=1200 | |||
| APP_MAX_ACTIVE_REQUESTS=0 | |||
| # Celery beat configuration | |||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||
| @@ -422,6 +421,22 @@ POSITION_PROVIDER_PINS= | |||
| POSITION_PROVIDER_INCLUDES= | |||
| POSITION_PROVIDER_EXCLUDES= | |||
| # Plugin configuration | |||
| PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi | |||
| PLUGIN_DAEMON_URL=http://127.0.0.1:5002 | |||
| PLUGIN_REMOTE_INSTALL_PORT=5003 | |||
| PLUGIN_REMOTE_INSTALL_HOST=localhost | |||
| PLUGIN_MAX_PACKAGE_SIZE=15728640 | |||
| INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 | |||
| INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 | |||
| # Marketplace configuration | |||
| MARKETPLACE_ENABLED=true | |||
| MARKETPLACE_API_URL=https://marketplace.dify.ai | |||
| # Endpoint configuration | |||
| ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} | |||
| # Reset password token expiry minutes | |||
| RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 | |||
| @@ -73,6 +73,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" | |||
| # Download nltk data | |||
| RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" | |||
| ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache | |||
| RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" | |||
| # Copy source code | |||
| COPY . /app/api/ | |||
| @@ -25,6 +25,8 @@ from models.dataset import Document as DatasetDocument | |||
| from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | |||
| from models.provider import Provider, ProviderModel | |||
| from services.account_service import RegisterService, TenantService | |||
| from services.plugin.data_migration import PluginDataMigration | |||
| from services.plugin.plugin_migration import PluginMigration | |||
| @click.command("reset-password", help="Reset the account password.") | |||
| @@ -524,7 +526,7 @@ def add_qdrant_doc_id_index(field: str): | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| except Exception: | |||
| click.echo(click.style("Failed to create Qdrant client.", fg="red")) | |||
| click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) | |||
| @@ -593,7 +595,7 @@ def upgrade_db(): | |||
| click.echo(click.style("Database migration successful!", fg="green")) | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("Failed to execute database migration") | |||
| finally: | |||
| lock.release() | |||
| @@ -639,7 +641,7 @@ where sites.id is null limit 1000""" | |||
| account = accounts[0] | |||
| print("Fixing missing site for app {}".format(app.id)) | |||
| app_was_created.send(app, account=account) | |||
| except Exception as e: | |||
| except Exception: | |||
| failed_app_ids.append(app_id) | |||
| click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) | |||
| logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") | |||
| @@ -649,3 +651,68 @@ where sites.id is null limit 1000""" | |||
| break | |||
| click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) | |||
| @click.command("migrate-data-for-plugin", help="Migrate data for plugin.") | |||
| def migrate_data_for_plugin(): | |||
| """ | |||
| Migrate data for plugin. | |||
| """ | |||
| click.echo(click.style("Starting migrate data for plugin.", fg="white")) | |||
| PluginDataMigration.migrate() | |||
| click.echo(click.style("Migrate data for plugin completed.", fg="green")) | |||
| @click.command("extract-plugins", help="Extract plugins.") | |||
| @click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") | |||
| @click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) | |||
| def extract_plugins(output_file: str, workers: int): | |||
| """ | |||
| Extract plugins. | |||
| """ | |||
| click.echo(click.style("Starting extract plugins.", fg="white")) | |||
| PluginMigration.extract_plugins(output_file, workers) | |||
| click.echo(click.style("Extract plugins completed.", fg="green")) | |||
| @click.command("extract-unique-identifiers", help="Extract unique identifiers.") | |||
| @click.option( | |||
| "--output_file", | |||
| prompt=True, | |||
| help="The file to store the extracted unique identifiers.", | |||
| default="unique_identifiers.json", | |||
| ) | |||
| @click.option( | |||
| "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" | |||
| ) | |||
| def extract_unique_plugins(output_file: str, input_file: str): | |||
| """ | |||
| Extract unique plugins. | |||
| """ | |||
| click.echo(click.style("Starting extract unique plugins.", fg="white")) | |||
| PluginMigration.extract_unique_plugins_to_file(input_file, output_file) | |||
| click.echo(click.style("Extract unique plugins completed.", fg="green")) | |||
| @click.command("install-plugins", help="Install plugins.") | |||
| @click.option( | |||
| "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" | |||
| ) | |||
| @click.option( | |||
| "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" | |||
| ) | |||
| def install_plugins(input_file: str, output_file: str): | |||
| """ | |||
| Install plugins. | |||
| """ | |||
| click.echo(click.style("Starting install plugins.", fg="white")) | |||
| PluginMigration.install_plugins(input_file, output_file) | |||
| click.echo(click.style("Install plugins completed.", fg="green")) | |||
| @@ -134,6 +134,60 @@ class CodeExecutionSandboxConfig(BaseSettings): | |||
| ) | |||
| class PluginConfig(BaseSettings): | |||
| """ | |||
| Plugin configs | |||
| """ | |||
| PLUGIN_DAEMON_URL: HttpUrl = Field( | |||
| description="Plugin API URL", | |||
| default="http://localhost:5002", | |||
| ) | |||
| PLUGIN_DAEMON_KEY: str = Field( | |||
| description="Plugin API key", | |||
| default="plugin-api-key", | |||
| ) | |||
| INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") | |||
| PLUGIN_REMOTE_INSTALL_HOST: str = Field( | |||
| description="Plugin Remote Install Host", | |||
| default="localhost", | |||
| ) | |||
| PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field( | |||
| description="Plugin Remote Install Port", | |||
| default=5003, | |||
| ) | |||
| PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field( | |||
| description="Maximum allowed size for plugin packages in bytes", | |||
| default=15728640, | |||
| ) | |||
| PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field( | |||
| description="Maximum allowed size for plugin bundles in bytes", | |||
| default=15728640 * 12, | |||
| ) | |||
| class MarketplaceConfig(BaseSettings): | |||
| """ | |||
| Configuration for marketplace | |||
| """ | |||
| MARKETPLACE_ENABLED: bool = Field( | |||
| description="Enable or disable marketplace", | |||
| default=True, | |||
| ) | |||
| MARKETPLACE_API_URL: HttpUrl = Field( | |||
| description="Marketplace API URL", | |||
| default="https://marketplace.dify.ai", | |||
| ) | |||
| class EndpointConfig(BaseSettings): | |||
| """ | |||
| Configuration for various application endpoints and URLs | |||
| @@ -160,6 +214,10 @@ class EndpointConfig(BaseSettings): | |||
| default="", | |||
| ) | |||
| ENDPOINT_URL_TEMPLATE: str = Field( | |||
| description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}" | |||
| ) | |||
| class FileAccessConfig(BaseSettings): | |||
| """ | |||
| @@ -793,6 +851,8 @@ class FeatureConfig( | |||
| AuthConfig, # Changed from OAuthConfig to AuthConfig | |||
| BillingConfig, | |||
| CodeExecutionSandboxConfig, | |||
| PluginConfig, | |||
| MarketplaceConfig, | |||
| DataSetConfig, | |||
| EndpointConfig, | |||
| FileAccessConfig, | |||
| @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): | |||
| CURRENT_VERSION: str = Field( | |||
| description="Dify version", | |||
| default="0.15.3", | |||
| default="1.0.0", | |||
| ) | |||
| COMMIT_SHA: str = Field( | |||
| @@ -1,9 +1,19 @@ | |||
| from contextvars import ContextVar | |||
| from threading import Lock | |||
| from typing import TYPE_CHECKING | |||
| if TYPE_CHECKING: | |||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||
| from core.tools.plugin_tool.provider import PluginToolProviderController | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| tenant_id: ContextVar[str] = ContextVar("tenant_id") | |||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | |||
| plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") | |||
| plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") | |||
| plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") | |||
| plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") | |||
| @@ -2,7 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| from .app.app_import import AppImportApi, AppImportConfirmApi | |||
| from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi | |||
| from .explore.audio import ChatAudioApi, ChatTextApi | |||
| from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi | |||
| from .explore.conversation import ( | |||
| @@ -40,6 +40,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload") | |||
| # Import App | |||
| api.add_resource(AppImportApi, "/apps/imports") | |||
| api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") | |||
| api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") | |||
| # Import other controllers | |||
| from . import admin, apikey, extension, feature, ping, setup, version | |||
| @@ -166,4 +167,15 @@ api.add_resource( | |||
| from .tag import tags | |||
| # Import workspace controllers | |||
| from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace | |||
| from .workspace import ( | |||
| account, | |||
| agent_providers, | |||
| endpoint, | |||
| load_balancing_config, | |||
| members, | |||
| model_providers, | |||
| models, | |||
| plugin, | |||
| tool_providers, | |||
| workspace, | |||
| ) | |||
| @@ -2,6 +2,8 @@ from functools import wraps | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse # type: ignore | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from configs import dify_config | |||
| @@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource): | |||
| parser.add_argument("position", type=int, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| app = App.query.filter(App.id == args["app_id"]).first() | |||
| with Session(db.engine) as session: | |||
| app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() | |||
| if not app: | |||
| raise NotFound(f"App '{args['app_id']}' is not found") | |||
| @@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource): | |||
| privacy_policy = site.privacy_policy or args["privacy_policy"] or "" | |||
| custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() | |||
| with Session(db.engine) as session: | |||
| recommended_app = session.execute( | |||
| select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) | |||
| ).scalar_one_or_none() | |||
| if not recommended_app: | |||
| recommended_app = RecommendedApp( | |||
| @@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource): | |||
| @only_edition_cloud | |||
| @admin_required | |||
| def delete(self, app_id): | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() | |||
| with Session(db.engine) as session: | |||
| recommended_app = session.execute( | |||
| select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) | |||
| ).scalar_one_or_none() | |||
| if not recommended_app: | |||
| return {"result": "success"}, 204 | |||
| app = App.query.filter(App.id == recommended_app.app_id).first() | |||
| with Session(db.engine) as session: | |||
| app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() | |||
| if app: | |||
| app.is_public = False | |||
| installed_apps = InstalledApp.query.filter( | |||
| InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id | |||
| ).all() | |||
| with Session(db.engine) as session: | |||
| installed_apps = session.execute( | |||
| select(InstalledApp).filter( | |||
| InstalledApp.app_id == recommended_app.app_id, | |||
| InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, | |||
| ) | |||
| ).all() | |||
| for installed_app in installed_apps: | |||
| db.session.delete(installed_app) | |||
| @@ -3,6 +3,8 @@ from typing import Any | |||
| import flask_restful # type: ignore | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, fields, marshal_with | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from extensions.ext_database import db | |||
| @@ -26,7 +28,16 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it | |||
| def _get_resource(resource_id, tenant_id, resource_model): | |||
| resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() | |||
| if resource_model == App: | |||
| with Session(db.engine) as session: | |||
| resource = session.execute( | |||
| select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) | |||
| ).scalar_one_or_none() | |||
| else: | |||
| with Session(db.engine) as session: | |||
| resource = session.execute( | |||
| select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) | |||
| ).scalar_one_or_none() | |||
| if resource is None: | |||
| flask_restful.abort(404, message=f"{resource_model.__name__} not found.") | |||
| @@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| setup_required, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from fields.app_fields import app_import_fields | |||
| from fields.app_fields import app_import_check_dependencies_fields, app_import_fields | |||
| from libs.login import login_required | |||
| from models import Account | |||
| from models.model import App | |||
| from services.app_dsl_service import AppDslService, ImportStatus | |||
| @@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource): | |||
| if result.status == ImportStatus.FAILED.value: | |||
| return result.model_dump(mode="json"), 400 | |||
| return result.model_dump(mode="json"), 200 | |||
| class AppImportCheckDependenciesApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @get_app_model | |||
| @account_initialization_required | |||
| @marshal_with(app_import_check_dependencies_fields) | |||
| def get(self, app_model: App): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| with Session(db.engine) as session: | |||
| import_service = AppDslService(session) | |||
| result = import_service.check_dependencies(app_model=app_model) | |||
| return result.model_dump(mode="json"), 200 | |||
| @@ -2,6 +2,7 @@ from datetime import UTC, datetime | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| from constants.languages import supported_language | |||
| @@ -50,33 +51,37 @@ class AppSite(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| site = Site.query.filter(Site.app_id == app_model.id).one_or_404() | |||
| for attr_name in [ | |||
| "title", | |||
| "icon_type", | |||
| "icon", | |||
| "icon_background", | |||
| "description", | |||
| "default_language", | |||
| "chat_color_theme", | |||
| "chat_color_theme_inverted", | |||
| "customize_domain", | |||
| "copyright", | |||
| "privacy_policy", | |||
| "custom_disclaimer", | |||
| "customize_token_strategy", | |||
| "prompt_public", | |||
| "show_workflow_steps", | |||
| "use_icon_as_answer_icon", | |||
| ]: | |||
| value = args.get(attr_name) | |||
| if value is not None: | |||
| setattr(site, attr_name, value) | |||
| site.updated_by = current_user.id | |||
| site.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| with Session(db.engine) as session: | |||
| site = session.query(Site).filter(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise NotFound | |||
| for attr_name in [ | |||
| "title", | |||
| "icon_type", | |||
| "icon", | |||
| "icon_background", | |||
| "description", | |||
| "default_language", | |||
| "chat_color_theme", | |||
| "chat_color_theme_inverted", | |||
| "customize_domain", | |||
| "copyright", | |||
| "privacy_policy", | |||
| "custom_disclaimer", | |||
| "customize_token_strategy", | |||
| "prompt_public", | |||
| "show_workflow_steps", | |||
| "use_icon_as_answer_icon", | |||
| ]: | |||
| value = args.get(attr_name) | |||
| if value is not None: | |||
| setattr(site, attr_name, value) | |||
| site.updated_by = current_user.id | |||
| site.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| session.commit() | |||
| return site | |||
| @@ -20,6 +20,7 @@ from libs import helper | |||
| from libs.helper import TimestampField, uuid_value | |||
| from libs.login import current_user, login_required | |||
| from models import App | |||
| from models.account import Account | |||
| from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| @@ -96,6 +97,9 @@ class DraftWorkflowApi(Resource): | |||
| else: | |||
| abort(415) | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| workflow_service = WorkflowService() | |||
| try: | |||
| @@ -139,6 +143,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json", default="") | |||
| @@ -160,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| raise ConversationCompletedError() | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -178,6 +185,9 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| args = parser.parse_args() | |||
| @@ -194,7 +204,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| raise ConversationCompletedError() | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -212,6 +222,9 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| args = parser.parse_args() | |||
| @@ -228,7 +241,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| raise ConversationCompletedError() | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -246,6 +259,9 @@ class DraftWorkflowRunApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| @@ -294,13 +310,20 @@ class DraftWorkflowNodeRunApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| inputs = args.get("inputs") | |||
| if inputs == None: | |||
| raise ValueError("missing inputs") | |||
| workflow_service = WorkflowService() | |||
| workflow_node_execution = workflow_service.run_draft_workflow_node( | |||
| app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user | |||
| app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user | |||
| ) | |||
| return workflow_node_execution | |||
| @@ -339,6 +362,9 @@ class PublishedWorkflowApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | |||
| @@ -376,12 +402,17 @@ class DefaultBlockConfigApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("q", type=str, location="args") | |||
| args = parser.parse_args() | |||
| q = args.get("q") | |||
| filters = None | |||
| if args.get("q"): | |||
| if q: | |||
| try: | |||
| filters = json.loads(args.get("q", "")) | |||
| except json.JSONDecodeError: | |||
| @@ -407,6 +438,9 @@ class ConvertToWorkflowApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| if request.data: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| @@ -3,6 +3,8 @@ import secrets | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse # type: ignore | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.languages import languages | |||
| from controllers.console import api | |||
| @@ -43,7 +45,8 @@ class ForgotPasswordSendEmailApi(Resource): | |||
| else: | |||
| language = "en-US" | |||
| account = Account.query.filter_by(email=args["email"]).first() | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() | |||
| token = None | |||
| if account is None: | |||
| if FeatureService.get_system_features().is_allow_register: | |||
| @@ -116,7 +119,8 @@ class ForgotPasswordResetApi(Resource): | |||
| password_hashed = hash_password(new_password, salt) | |||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | |||
| account = Account.query.filter_by(email=reset_data.get("email")).first() | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none() | |||
| if account: | |||
| account.password = base64_password_hashed | |||
| account.password_salt = base64_salt | |||
| @@ -137,7 +141,7 @@ class ForgotPasswordResetApi(Resource): | |||
| ) | |||
| except WorkSpaceNotAllowedCreateError: | |||
| pass | |||
| except AccountRegisterError as are: | |||
| except AccountRegisterError: | |||
| raise AccountInFreezeError() | |||
| return {"result": "success"} | |||
| @@ -5,6 +5,8 @@ from typing import Optional | |||
| import requests | |||
| from flask import current_app, redirect, request | |||
| from flask_restful import Resource # type: ignore | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Unauthorized | |||
| from configs import dify_config | |||
| @@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> | |||
| account: Optional[Account] = Account.get_by_openid(provider, user_info.id) | |||
| if not account: | |||
| account = Account.query.filter_by(email=user_info.email).first() | |||
| with Session(db.engine) as session: | |||
| account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() | |||
| return account | |||
| @@ -4,6 +4,8 @@ import json | |||
| from flask import request | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| @@ -76,7 +78,10 @@ class DataSourceApi(Resource): | |||
| def patch(self, binding_id, action): | |||
| binding_id = str(binding_id) | |||
| action = str(action) | |||
| data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() | |||
| with Session(db.engine) as session: | |||
| data_source_binding = session.execute( | |||
| select(DataSourceOauthBinding).filter_by(id=binding_id) | |||
| ).scalar_one_or_none() | |||
| if data_source_binding is None: | |||
| raise NotFound("Data source binding not found.") | |||
| # enable binding | |||
| @@ -108,47 +113,53 @@ class DataSourceNotionListApi(Resource): | |||
| def get(self): | |||
| dataset_id = request.args.get("dataset_id", default=None, type=str) | |||
| exist_page_ids = [] | |||
| # import notion in the exist dataset | |||
| if dataset_id: | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| if dataset.data_source_type != "notion_import": | |||
| raise ValueError("Dataset is not notion type.") | |||
| documents = Document.query.filter_by( | |||
| dataset_id=dataset_id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="notion_import", | |||
| enabled=True, | |||
| with Session(db.engine) as session: | |||
| # import notion in the exist dataset | |||
| if dataset_id: | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| if dataset.data_source_type != "notion_import": | |||
| raise ValueError("Dataset is not notion type.") | |||
| documents = session.execute( | |||
| select(Document).filter_by( | |||
| dataset_id=dataset_id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="notion_import", | |||
| enabled=True, | |||
| ) | |||
| ).all() | |||
| if documents: | |||
| for document in documents: | |||
| data_source_info = json.loads(document.data_source_info) | |||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||
| # get all authorized pages | |||
| data_source_bindings = session.scalars( | |||
| select(DataSourceOauthBinding).filter_by( | |||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||
| ) | |||
| ).all() | |||
| if documents: | |||
| for document in documents: | |||
| data_source_info = json.loads(document.data_source_info) | |||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||
| # get all authorized pages | |||
| data_source_bindings = DataSourceOauthBinding.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||
| ).all() | |||
| if not data_source_bindings: | |||
| return {"notion_info": []}, 200 | |||
| pre_import_info_list = [] | |||
| for data_source_binding in data_source_bindings: | |||
| source_info = data_source_binding.source_info | |||
| pages = source_info["pages"] | |||
| # Filter out already bound pages | |||
| for page in pages: | |||
| if page["page_id"] in exist_page_ids: | |||
| page["is_bound"] = True | |||
| else: | |||
| page["is_bound"] = False | |||
| pre_import_info = { | |||
| "workspace_name": source_info["workspace_name"], | |||
| "workspace_icon": source_info["workspace_icon"], | |||
| "workspace_id": source_info["workspace_id"], | |||
| "pages": pages, | |||
| } | |||
| pre_import_info_list.append(pre_import_info) | |||
| return {"notion_info": pre_import_info_list}, 200 | |||
| if not data_source_bindings: | |||
| return {"notion_info": []}, 200 | |||
| pre_import_info_list = [] | |||
| for data_source_binding in data_source_bindings: | |||
| source_info = data_source_binding.source_info | |||
| pages = source_info["pages"] | |||
| # Filter out already bound pages | |||
| for page in pages: | |||
| if page["page_id"] in exist_page_ids: | |||
| page["is_bound"] = True | |||
| else: | |||
| page["is_bound"] = False | |||
| pre_import_info = { | |||
| "workspace_name": source_info["workspace_name"], | |||
| "workspace_icon": source_info["workspace_icon"], | |||
| "workspace_id": source_info["workspace_id"], | |||
| "pages": pages, | |||
| } | |||
| pre_import_info_list.append(pre_import_info) | |||
| return {"notion_info": pre_import_info_list}, 200 | |||
| class DataSourceNotionApi(Resource): | |||
| @@ -158,14 +169,17 @@ class DataSourceNotionApi(Resource): | |||
| def get(self, workspace_id, page_id, page_type): | |||
| workspace_id = str(workspace_id) | |||
| page_id = str(page_id) | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ).first() | |||
| with Session(db.engine) as session: | |||
| data_source_binding = session.execute( | |||
| select(DataSourceOauthBinding).filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ) | |||
| ).scalar_one_or_none() | |||
| if not data_source_binding: | |||
| raise NotFound("Data source binding not found.") | |||
| @@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, enterpris | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| @@ -72,7 +73,9 @@ class DatasetListApi(Resource): | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| # convert embedding_model_provider to plugin standard format | |||
| if item["indexing_technique"] == "high_quality": | |||
| item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) | |||
| item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
| item["embedding_available"] = True | |||
| @@ -7,7 +7,6 @@ from flask import request | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy import asc, desc | |||
| from transformers.hf_argparser import string_to_bool # type: ignore | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services | |||
| @@ -40,6 +39,7 @@ from core.indexing_runner import IndexingRunner | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.plugin.manager.exc import PluginDaemonClientSideError | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -150,8 +150,20 @@ class DatasetDocumentListApi(Resource): | |||
| sort = request.args.get("sort", default="-created_at", type=str) | |||
| # "yes", "true", "t", "y", "1" convert to True, while others convert to False. | |||
| try: | |||
| fetch = string_to_bool(request.args.get("fetch", default="false")) | |||
| except (ArgumentTypeError, ValueError, Exception) as e: | |||
| fetch_val = request.args.get("fetch", default="false") | |||
| if isinstance(fetch_val, bool): | |||
| fetch = fetch_val | |||
| else: | |||
| if fetch_val.lower() in ("yes", "true", "t", "y", "1"): | |||
| fetch = True | |||
| elif fetch_val.lower() in ("no", "false", "f", "n", "0"): | |||
| fetch = False | |||
| else: | |||
| raise ArgumentTypeError( | |||
| f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 " | |||
| f"(case insensitive)." | |||
| ) | |||
| except (ArgumentTypeError, ValueError, Exception): | |||
| fetch = False | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| @@ -429,6 +441,8 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except PluginDaemonClientSideError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| @@ -529,6 +543,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except PluginDaemonClientSideError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| @@ -2,8 +2,11 @@ import os | |||
| from flask import session | |||
| from flask_restful import Resource, reqparse # type: ignore | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from extensions.ext_database import db | |||
| from libs.helper import StrLen | |||
| from models.model import DifySetup | |||
| from services.account_service import TenantService | |||
| @@ -42,7 +45,11 @@ class InitValidateAPI(Resource): | |||
| def get_init_validate_status(): | |||
| if dify_config.EDITION == "SELF_HOSTED": | |||
| if os.environ.get("INIT_PASSWORD"): | |||
| return session.get("is_init_validated") or DifySetup.query.first() | |||
| if session.get("is_init_validated"): | |||
| return True | |||
| with Session(db.engine) as db_session: | |||
| return db_session.execute(select(DifySetup)).scalar_one_or_none() | |||
| return True | |||
| @@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore | |||
| from configs import dify_config | |||
| from libs.helper import StrLen, email, extract_remote_ip | |||
| from libs.password import valid_password | |||
| from models.model import DifySetup | |||
| from models.model import DifySetup, db | |||
| from services.account_service import RegisterService, TenantService | |||
| from . import api | |||
| @@ -52,8 +52,9 @@ class SetupApi(Resource): | |||
| def get_setup_status(): | |||
| if dify_config.EDITION == "SELF_HOSTED": | |||
| return DifySetup.query.first() | |||
| return True | |||
| return db.session.query(DifySetup).first() | |||
| else: | |||
| return True | |||
| api.add_resource(SetupApi, "/setup") | |||
| @@ -0,0 +1,56 @@ | |||
| from functools import wraps | |||
| from flask_login import current_user # type: ignore | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from extensions.ext_database import db | |||
| from models.account import TenantPluginPermission | |||
| def plugin_permission_required( | |||
| install_required: bool = False, | |||
| debug_required: bool = False, | |||
| ): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| with Session(db.engine) as session: | |||
| permission = ( | |||
| session.query(TenantPluginPermission) | |||
| .filter( | |||
| TenantPluginPermission.tenant_id == tenant_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not permission: | |||
| # no permission set, allow access for everyone | |||
| return view(*args, **kwargs) | |||
| if install_required: | |||
| if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: | |||
| raise Forbidden() | |||
| if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: | |||
| pass | |||
| if debug_required: | |||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: | |||
| raise Forbidden() | |||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: | |||
| pass | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| return interceptor | |||
| @@ -0,0 +1,36 @@ | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource # type: ignore | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.login import login_required | |||
| from services.agent_service import AgentService | |||
| class AgentProviderListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) | |||
| class AgentProviderApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) | |||
| api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") | |||
| api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>") | |||
| @@ -0,0 +1,205 @@ | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, reqparse # type: ignore | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.login import login_required | |||
| from services.plugin.endpoint_service import EndpointService | |||
| class EndpointCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("plugin_unique_identifier", type=str, required=True) | |||
| parser.add_argument("settings", type=dict, required=True) | |||
| parser.add_argument("name", type=str, required=True) | |||
| args = parser.parse_args() | |||
| plugin_unique_identifier = args["plugin_unique_identifier"] | |||
| settings = args["settings"] | |||
| name = args["name"] | |||
| return { | |||
| "success": EndpointService.create_endpoint( | |||
| tenant_id=user.current_tenant_id, | |||
| user_id=user.id, | |||
| plugin_unique_identifier=plugin_unique_identifier, | |||
| name=name, | |||
| settings=settings, | |||
| ) | |||
| } | |||
| class EndpointListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("page", type=int, required=True, location="args") | |||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||
| args = parser.parse_args() | |||
| page = args["page"] | |||
| page_size = args["page_size"] | |||
| return jsonable_encoder( | |||
| { | |||
| "endpoints": EndpointService.list_endpoints( | |||
| tenant_id=user.current_tenant_id, | |||
| user_id=user.id, | |||
| page=page, | |||
| page_size=page_size, | |||
| ) | |||
| } | |||
| ) | |||
| class EndpointListForSinglePluginApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("page", type=int, required=True, location="args") | |||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||
| parser.add_argument("plugin_id", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| page = args["page"] | |||
| page_size = args["page_size"] | |||
| plugin_id = args["plugin_id"] | |||
| return jsonable_encoder( | |||
| { | |||
| "endpoints": EndpointService.list_endpoints_for_single_plugin( | |||
| tenant_id=user.current_tenant_id, | |||
| user_id=user.id, | |||
| plugin_id=plugin_id, | |||
| page=page, | |||
| page_size=page_size, | |||
| ) | |||
| } | |||
| ) | |||
| class EndpointDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("endpoint_id", type=str, required=True) | |||
| args = parser.parse_args() | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| endpoint_id = args["endpoint_id"] | |||
| return { | |||
| "success": EndpointService.delete_endpoint( | |||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||
| ) | |||
| } | |||
| class EndpointUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("endpoint_id", type=str, required=True) | |||
| parser.add_argument("settings", type=dict, required=True) | |||
| parser.add_argument("name", type=str, required=True) | |||
| args = parser.parse_args() | |||
| endpoint_id = args["endpoint_id"] | |||
| settings = args["settings"] | |||
| name = args["name"] | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| return { | |||
| "success": EndpointService.update_endpoint( | |||
| tenant_id=user.current_tenant_id, | |||
| user_id=user.id, | |||
| endpoint_id=endpoint_id, | |||
| name=name, | |||
| settings=settings, | |||
| ) | |||
| } | |||
| class EndpointEnableApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("endpoint_id", type=str, required=True) | |||
| args = parser.parse_args() | |||
| endpoint_id = args["endpoint_id"] | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| return { | |||
| "success": EndpointService.enable_endpoint( | |||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||
| ) | |||
| } | |||
| class EndpointDisableApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("endpoint_id", type=str, required=True) | |||
| args = parser.parse_args() | |||
| endpoint_id = args["endpoint_id"] | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| return { | |||
| "success": EndpointService.disable_endpoint( | |||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||
| ) | |||
| } | |||
| api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") | |||
| api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") | |||
| api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") | |||
| api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") | |||
| api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") | |||
| api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") | |||
| api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") | |||
| @@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): | |||
| # Load Balancing Config | |||
| api.add_resource( | |||
| LoadBalancingCredentialsValidateApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate", | |||
| "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate", | |||
| ) | |||
| api.add_resource( | |||
| LoadBalancingConfigCredentialsValidateApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", | |||
| "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", | |||
| ) | |||
| @@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource): | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error | |||
| response["error"] = error or "Unknown error" | |||
| return response | |||
| @@ -125,9 +125,10 @@ class ModelProviderIconApi(Resource): | |||
| Get model provider icon | |||
| """ | |||
| def get(self, provider: str, icon_type: str, lang: str): | |||
| def get(self, tenant_id: str, provider: str, icon_type: str, lang: str): | |||
| model_provider_service = ModelProviderService() | |||
| icon, mimetype = model_provider_service.get_model_provider_icon( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| icon_type=icon_type, | |||
| lang=lang, | |||
| @@ -183,53 +184,17 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| return data | |||
| class ModelProviderFreeQuotaSubmitApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) | |||
| return result | |||
| class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("token", type=str, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_qualification_verify( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] | |||
| ) | |||
| return result | |||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | |||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials") | |||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate") | |||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>") | |||
| api.add_resource( | |||
| ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>" | |||
| ) | |||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | |||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | |||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>") | |||
| api.add_resource( | |||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit" | |||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" | |||
| ) | |||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url") | |||
| api.add_resource( | |||
| ModelProviderFreeQuotaQualificationVerifyApi, | |||
| "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify", | |||
| ModelProviderIconApi, | |||
| "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>", | |||
| ) | |||
| @@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource): | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error | |||
| response["error"] = error or "" | |||
| return response | |||
| @@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource): | |||
| return jsonable_encoder({"data": models}) | |||
| api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models") | |||
| api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models") | |||
| api.add_resource( | |||
| ModelProviderModelEnableApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/enable", | |||
| "/workspaces/current/model-providers/<path:provider>/models/enable", | |||
| endpoint="model-provider-model-enable", | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelDisableApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/disable", | |||
| "/workspaces/current/model-providers/<path:provider>/models/disable", | |||
| endpoint="model-provider-model-disable", | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials" | |||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate" | |||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules" | |||
| ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules" | |||
| ) | |||
| api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") | |||
| api.add_resource(DefaultModelApi, "/workspaces/current/default-model") | |||
| @@ -0,0 +1,475 @@ | |||
| import io | |||
| from flask import request, send_file | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restful import Resource, reqparse # type: ignore | |||
| from werkzeug.exceptions import Forbidden | |||
| from configs import dify_config | |||
| from controllers.console import api | |||
| from controllers.console.workspace import plugin_permission_required | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.manager.exc import PluginDaemonClientSideError | |||
| from libs.login import login_required | |||
| from models.account import TenantPluginPermission | |||
| from services.plugin.plugin_permission_service import PluginPermissionService | |||
| from services.plugin.plugin_service import PluginService | |||
| class PluginDebuggingKeyApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return { | |||
| "key": PluginService.get_debugging_key(tenant_id), | |||
| "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, | |||
| "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, | |||
| } | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| plugins = PluginService.list(tenant_id) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder({"plugins": plugins}) | |||
| class PluginListInstallationsFromIdsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("plugin_ids", type=list, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"]) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder({"plugins": plugins}) | |||
| class PluginIconApi(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| req = reqparse.RequestParser() | |||
| req.add_argument("tenant_id", type=str, required=True, location="args") | |||
| req.add_argument("filename", type=str, required=True, location="args") | |||
| args = req.parse_args() | |||
| try: | |||
| icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"]) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE | |||
| return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) | |||
| class PluginUploadFromPkgApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| file = request.files["pkg"] | |||
| # check file size | |||
| if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE: | |||
| raise ValueError("File size exceeds the maximum allowed size") | |||
| content = file.read() | |||
| try: | |||
| response = PluginService.upload_pkg(tenant_id, content) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginUploadFromGithubApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("repo", type=str, required=True, location="json") | |||
| parser.add_argument("version", type=str, required=True, location="json") | |||
| parser.add_argument("package", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"]) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginUploadFromBundleApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| file = request.files["bundle"] | |||
| # check file size | |||
| if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE: | |||
| raise ValueError("File size exceeds the maximum allowed size") | |||
| content = file.read() | |||
| try: | |||
| response = PluginService.upload_bundle(tenant_id, content) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginInstallFromPkgApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # check if all plugin_unique_identifiers are valid string | |||
| for plugin_unique_identifier in args["plugin_unique_identifiers"]: | |||
| if not isinstance(plugin_unique_identifier, str): | |||
| raise ValueError("Invalid plugin unique identifier") | |||
| try: | |||
| response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"]) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginInstallFromGithubApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("repo", type=str, required=True, location="json") | |||
| parser.add_argument("version", type=str, required=True, location="json") | |||
| parser.add_argument("package", type=str, required=True, location="json") | |||
| parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = PluginService.install_from_github( | |||
| tenant_id, | |||
| args["plugin_unique_identifier"], | |||
| args["repo"], | |||
| args["version"], | |||
| args["package"], | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginInstallFromMarketplaceApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(install_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # check if all plugin_unique_identifiers are valid string | |||
| for plugin_unique_identifier in args["plugin_unique_identifiers"]: | |||
| if not isinstance(plugin_unique_identifier, str): | |||
| raise ValueError("Invalid plugin unique identifier") | |||
| try: | |||
| response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"]) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| return jsonable_encoder(response) | |||
| class PluginFetchManifestApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| return jsonable_encoder( | |||
| { | |||
| "manifest": PluginService.fetch_plugin_manifest( | |||
| tenant_id, args["plugin_unique_identifier"] | |||
| ).model_dump() | |||
| } | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginFetchInstallTasksApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("page", type=int, required=True, location="args") | |||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| return jsonable_encoder( | |||
| {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])} | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginFetchInstallTaskApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def get(self, task_id: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginDeleteInstallTaskApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self, task_id: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return {"success": PluginService.delete_install_task(tenant_id, task_id)} | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginDeleteAllInstallTaskItemsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return {"success": PluginService.delete_all_install_task_items(tenant_id)} | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginDeleteInstallTaskItemApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self, task_id: str, identifier: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginUpgradeFromMarketplaceApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") | |||
| parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| return jsonable_encoder( | |||
| PluginService.upgrade_plugin_with_marketplace( | |||
| tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"] | |||
| ) | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginUpgradeFromGithubApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") | |||
| parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") | |||
| parser.add_argument("repo", type=str, required=True, location="json") | |||
| parser.add_argument("version", type=str, required=True, location="json") | |||
| parser.add_argument("package", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| return jsonable_encoder( | |||
| PluginService.upgrade_plugin_with_github( | |||
| tenant_id, | |||
| args["original_plugin_unique_identifier"], | |||
| args["new_plugin_unique_identifier"], | |||
| args["repo"], | |||
| args["version"], | |||
| args["package"], | |||
| ) | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginUninstallApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @plugin_permission_required(debug_required=True) | |||
| def post(self): | |||
| req = reqparse.RequestParser() | |||
| req.add_argument("plugin_installation_id", type=str, required=True, location="json") | |||
| args = req.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| try: | |||
| return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} | |||
| except PluginDaemonClientSideError as e: | |||
| raise ValueError(e) | |||
| class PluginChangePermissionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| req = reqparse.RequestParser() | |||
| req.add_argument("install_permission", type=str, required=True, location="json") | |||
| req.add_argument("debug_permission", type=str, required=True, location="json") | |||
| args = req.parse_args() | |||
| install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) | |||
| debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) | |||
| tenant_id = user.current_tenant_id | |||
| return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} | |||
| class PluginFetchPermissionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| permission = PluginPermissionService.get_permission(tenant_id) | |||
| if not permission: | |||
| return jsonable_encoder( | |||
| { | |||
| "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, | |||
| "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, | |||
| } | |||
| ) | |||
| return jsonable_encoder( | |||
| { | |||
| "install_permission": permission.install_permission, | |||
| "debug_permission": permission.debug_permission, | |||
| } | |||
| ) | |||
| api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") | |||
| api.add_resource(PluginListApi, "/workspaces/current/plugin/list") | |||
| api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") | |||
| api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") | |||
| api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") | |||
| api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github") | |||
| api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle") | |||
| api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg") | |||
| api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github") | |||
| api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace") | |||
| api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github") | |||
| api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace") | |||
| api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest") | |||
| api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks") | |||
| api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>") | |||
| api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete") | |||
| api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") | |||
| api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") | |||
| api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") | |||
| api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") | |||
| api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") | |||
| @@ -25,8 +25,10 @@ class ToolProviderListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| req = reqparse.RequestParser() | |||
| req.add_argument( | |||
| @@ -47,28 +49,43 @@ class ToolBuiltinProviderListToolsApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder( | |||
| BuiltinToolManageService.list_builtin_tool_provider_tools( | |||
| user_id, | |||
| tenant_id, | |||
| provider, | |||
| ) | |||
| ) | |||
| class ToolBuiltinProviderInfoApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider): | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) | |||
| class ToolBuiltinProviderDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return BuiltinToolManageService.delete_builtin_tool_provider( | |||
| user_id, | |||
| @@ -82,11 +99,13 @@ class ToolBuiltinProviderUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| @@ -131,11 +150,13 @@ class ToolApiProviderAddApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| @@ -168,6 +189,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("url", type=str, required=True, nullable=False, location="args") | |||
| @@ -175,8 +201,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.get_api_tool_provider_remote_schema( | |||
| current_user.id, | |||
| current_user.current_tenant_id, | |||
| user_id, | |||
| tenant_id, | |||
| args["url"], | |||
| ) | |||
| @@ -186,8 +212,10 @@ class ToolApiProviderListToolsApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| @@ -209,11 +237,13 @@ class ToolApiProviderUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| @@ -248,11 +278,13 @@ class ToolApiProviderDeleteApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| @@ -272,8 +304,10 @@ class ToolApiProviderGetApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| @@ -293,7 +327,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider): | |||
| return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) | |||
| class ToolApiProviderSchemaApi(Resource): | |||
| @@ -344,11 +382,13 @@ class ToolWorkflowProviderCreateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| @@ -381,11 +421,13 @@ class ToolWorkflowProviderUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| @@ -421,11 +463,13 @@ class ToolWorkflowProviderDeleteApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| @@ -444,8 +488,10 @@ class ToolWorkflowProviderGetApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| @@ -476,8 +522,10 @@ class ToolWorkflowProviderListToolApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") | |||
| @@ -498,8 +546,10 @@ class ToolBuiltinListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder( | |||
| [ | |||
| @@ -517,8 +567,10 @@ class ToolApiListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder( | |||
| [ | |||
| @@ -536,8 +588,10 @@ class ToolWorkflowListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| user = current_user | |||
| user_id = user.id | |||
| tenant_id = user.current_tenant_id | |||
| return jsonable_encoder( | |||
| [ | |||
| @@ -563,16 +617,18 @@ class ToolLabelsApi(Resource): | |||
| api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") | |||
| # builtin tool provider | |||
| api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools") | |||
| api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete") | |||
| api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update") | |||
| api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools") | |||
| api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info") | |||
| api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete") | |||
| api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update") | |||
| api.add_resource( | |||
| ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials" | |||
| ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials" | |||
| ) | |||
| api.add_resource( | |||
| ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema" | |||
| ToolBuiltinProviderCredentialsSchemaApi, | |||
| "/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema", | |||
| ) | |||
| api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon") | |||
| api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon") | |||
| # api tool provider | |||
| api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") | |||
| @@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore | |||
| from configs import dify_config | |||
| from controllers.console.workspace.error import AccountNotInitializedError | |||
| from extensions.ext_database import db | |||
| from models.model import DifySetup | |||
| from services.feature_service import FeatureService, LicenseStatus | |||
| from services.operation_service import OperationService | |||
| @@ -134,9 +135,13 @@ def setup_required(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| # check setup | |||
| if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first(): | |||
| if ( | |||
| dify_config.EDITION == "SELF_HOSTED" | |||
| and os.environ.get("INIT_PASSWORD") | |||
| and not db.session.query(DifySetup).first() | |||
| ): | |||
| raise NotInitValidateError() | |||
| elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first(): | |||
| elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): | |||
| raise NotSetupError() | |||
| return view(*args, **kwargs) | |||
| @@ -6,4 +6,4 @@ bp = Blueprint("files", __name__) | |||
| api = ExternalApi(bp) | |||
| from . import image_preview, tool_files | |||
| from . import image_preview, tool_files, upload | |||
| @@ -0,0 +1,69 @@ | |||
| from flask import request | |||
| from flask_restful import Resource, marshal_with # type: ignore | |||
| from werkzeug.exceptions import Forbidden | |||
| import services | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.files import api | |||
| from controllers.files.error import UnsupportedFileTypeError | |||
| from controllers.inner_api.plugin.wraps import get_user | |||
| from controllers.service_api.app.error import FileTooLargeError | |||
| from core.file.helpers import verify_plugin_file_signature | |||
| from fields.file_fields import file_fields | |||
| from services.file_service import FileService | |||
| class PluginUploadFileApi(Resource): | |||
| @setup_required | |||
| @marshal_with(file_fields) | |||
| def post(self): | |||
| # get file from request | |||
| file = request.files["file"] | |||
| timestamp = request.args.get("timestamp") | |||
| nonce = request.args.get("nonce") | |||
| sign = request.args.get("sign") | |||
| tenant_id = request.args.get("tenant_id") | |||
| if not tenant_id: | |||
| raise Forbidden("Invalid request.") | |||
| user_id = request.args.get("user_id") | |||
| user = get_user(tenant_id, user_id) | |||
| filename = file.filename | |||
| mimetype = file.mimetype | |||
| if not filename or not mimetype: | |||
| raise Forbidden("Invalid request.") | |||
| if not timestamp or not nonce or not sign: | |||
| raise Forbidden("Invalid request.") | |||
| if not verify_plugin_file_signature( | |||
| filename=filename, | |||
| mimetype=mimetype, | |||
| tenant_id=tenant_id, | |||
| user_id=user_id, | |||
| timestamp=timestamp, | |||
| nonce=nonce, | |||
| sign=sign, | |||
| ): | |||
| raise Forbidden("Invalid request.") | |||
| try: | |||
| upload_file = FileService.upload_file( | |||
| filename=filename, | |||
| content=file.read(), | |||
| mimetype=mimetype, | |||
| user=user, | |||
| source=None, | |||
| ) | |||
| except services.errors.file.FileTooLargeError as file_too_large_error: | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| return upload_file, 201 | |||
| api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") | |||
| @@ -5,4 +5,5 @@ from libs.external_api import ExternalApi | |||
| bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") | |||
| api = ExternalApi(bp) | |||
| from .plugin import plugin | |||
| from .workspace import workspace | |||
| @@ -0,0 +1,293 @@ | |||
| from flask_restful import Resource # type: ignore | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data | |||
| from controllers.inner_api.wraps import plugin_inner_api_only | |||
| from core.file.helpers import get_signed_file_url_for_plugin | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation | |||
| from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse | |||
| from core.plugin.backwards_invocation.encrypt import PluginEncrypter | |||
| from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation | |||
| from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation | |||
| from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation | |||
| from core.plugin.entities.request import ( | |||
| RequestInvokeApp, | |||
| RequestInvokeEncrypt, | |||
| RequestInvokeLLM, | |||
| RequestInvokeModeration, | |||
| RequestInvokeParameterExtractorNode, | |||
| RequestInvokeQuestionClassifierNode, | |||
| RequestInvokeRerank, | |||
| RequestInvokeSpeech2Text, | |||
| RequestInvokeSummary, | |||
| RequestInvokeTextEmbedding, | |||
| RequestInvokeTool, | |||
| RequestInvokeTTS, | |||
| RequestRequestUploadFile, | |||
| ) | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from libs.helper import compact_generate_response | |||
| from models.account import Account, Tenant | |||
| from models.model import EndUser | |||
| class PluginInvokeLLMApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeLLM) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): | |||
| def generator(): | |||
| response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) | |||
| return PluginModelBackwardsInvocation.convert_to_event_stream(response) | |||
| return compact_generate_response(generator()) | |||
| class PluginInvokeTextEmbeddingApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTextEmbedding) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginModelBackwardsInvocation.invoke_text_embedding( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeRerankApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeRerank) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginModelBackwardsInvocation.invoke_rerank( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeTTSApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTTS) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS): | |||
| def generator(): | |||
| response = PluginModelBackwardsInvocation.invoke_tts( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| return PluginModelBackwardsInvocation.convert_to_event_stream(response) | |||
| return compact_generate_response(generator()) | |||
| class PluginInvokeSpeech2TextApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSpeech2Text) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginModelBackwardsInvocation.invoke_speech2text( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeModerationApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeModeration) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginModelBackwardsInvocation.invoke_moderation( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeToolApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTool) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): | |||
| def generator(): | |||
| return PluginToolBackwardsInvocation.convert_to_event_stream( | |||
| PluginToolBackwardsInvocation.invoke_tool( | |||
| tenant_id=tenant_model.id, | |||
| user_id=user_model.id, | |||
| tool_type=ToolProviderType.value_of(payload.tool_type), | |||
| provider=payload.provider, | |||
| tool_name=payload.tool, | |||
| tool_parameters=payload.tool_parameters, | |||
| ), | |||
| ) | |||
| return compact_generate_response(generator()) | |||
| class PluginInvokeParameterExtractorNodeApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeParameterExtractorNode) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginNodeBackwardsInvocation.invoke_parameter_extractor( | |||
| tenant_id=tenant_model.id, | |||
| user_id=user_model.id, | |||
| parameters=payload.parameters, | |||
| model_config=payload.model, | |||
| instruction=payload.instruction, | |||
| query=payload.query, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeQuestionClassifierNodeApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): | |||
| try: | |||
| return jsonable_encoder( | |||
| BaseBackwardsInvocationResponse( | |||
| data=PluginNodeBackwardsInvocation.invoke_question_classifier( | |||
| tenant_id=tenant_model.id, | |||
| user_id=user_model.id, | |||
| query=payload.query, | |||
| model_config=payload.model, | |||
| classes=payload.classes, | |||
| instruction=payload.instruction, | |||
| ) | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| class PluginInvokeAppApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeApp) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): | |||
| response = PluginAppBackwardsInvocation.invoke_app( | |||
| app_id=payload.app_id, | |||
| user_id=user_model.id, | |||
| tenant_id=tenant_model.id, | |||
| conversation_id=payload.conversation_id, | |||
| query=payload.query, | |||
| stream=payload.response_mode == "streaming", | |||
| inputs=payload.inputs, | |||
| files=payload.files, | |||
| ) | |||
| return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response)) | |||
| class PluginInvokeEncryptApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeEncrypt) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt): | |||
| """ | |||
| encrypt or decrypt data | |||
| """ | |||
| try: | |||
| return BaseBackwardsInvocationResponse( | |||
| data=PluginEncrypter.invoke_encrypt(tenant_model, payload) | |||
| ).model_dump() | |||
| except Exception as e: | |||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||
| class PluginInvokeSummaryApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSummary) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary): | |||
| try: | |||
| return BaseBackwardsInvocationResponse( | |||
| data={ | |||
| "summary": PluginModelBackwardsInvocation.invoke_summary( | |||
| user_id=user_model.id, | |||
| tenant=tenant_model, | |||
| payload=payload, | |||
| ) | |||
| } | |||
| ).model_dump() | |||
| except Exception as e: | |||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||
| class PluginUploadFileRequestApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestRequestUploadFile) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): | |||
| # generate signed url | |||
| url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) | |||
| return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() | |||
| api.add_resource(PluginInvokeLLMApi, "/invoke/llm") | |||
| api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") | |||
| api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") | |||
| api.add_resource(PluginInvokeTTSApi, "/invoke/tts") | |||
| api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text") | |||
| api.add_resource(PluginInvokeModerationApi, "/invoke/moderation") | |||
| api.add_resource(PluginInvokeToolApi, "/invoke/tool") | |||
| api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor") | |||
| api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier") | |||
| api.add_resource(PluginInvokeAppApi, "/invoke/app") | |||
| api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") | |||
| api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") | |||
| api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") | |||
| @@ -0,0 +1,116 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Optional | |||
| from flask import request | |||
| from flask_restful import reqparse # type: ignore | |||
| from pydantic import BaseModel | |||
| from sqlalchemy.orm import Session | |||
| from extensions.ext_database import db | |||
| from models.account import Account, Tenant | |||
| from models.model import EndUser | |||
| from services.account_service import AccountService | |||
| def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: | |||
| try: | |||
| with Session(db.engine) as session: | |||
| if not user_id: | |||
| user_id = "DEFAULT-USER" | |||
| if user_id == "DEFAULT-USER": | |||
| user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() | |||
| if not user_model: | |||
| user_model = EndUser( | |||
| tenant_id=tenant_id, | |||
| type="service_api", | |||
| is_anonymous=True if user_id == "DEFAULT-USER" else False, | |||
| session_id=user_id, | |||
| ) | |||
| session.add(user_model) | |||
| session.commit() | |||
| else: | |||
| user_model = AccountService.load_user(user_id) | |||
| if not user_model: | |||
| user_model = session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| if not user_model: | |||
| raise ValueError("user not found") | |||
| except Exception: | |||
| raise ValueError("user not found") | |||
| return user_model | |||
| def get_user_tenant(view: Optional[Callable] = None): | |||
| def decorator(view_func): | |||
| @wraps(view_func) | |||
| def decorated_view(*args, **kwargs): | |||
| # fetch json body | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tenant_id", type=str, required=True, location="json") | |||
| parser.add_argument("user_id", type=str, required=True, location="json") | |||
| kwargs = parser.parse_args() | |||
| user_id = kwargs.get("user_id") | |||
| tenant_id = kwargs.get("tenant_id") | |||
| if not tenant_id: | |||
| raise ValueError("tenant_id is required") | |||
| if not user_id: | |||
| user_id = "DEFAULT-USER" | |||
| del kwargs["tenant_id"] | |||
| del kwargs["user_id"] | |||
| try: | |||
| tenant_model = ( | |||
| db.session.query(Tenant) | |||
| .filter( | |||
| Tenant.id == tenant_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| except Exception: | |||
| raise ValueError("tenant not found") | |||
| if not tenant_model: | |||
| raise ValueError("tenant not found") | |||
| kwargs["tenant_model"] = tenant_model | |||
| kwargs["user_model"] = get_user(tenant_id, user_id) | |||
| return view_func(*args, **kwargs) | |||
| return decorated_view | |||
| if view is None: | |||
| return decorator | |||
| else: | |||
| return decorator(view) | |||
| def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): | |||
| def decorator(view_func): | |||
| def decorated_view(*args, **kwargs): | |||
| try: | |||
| data = request.get_json() | |||
| except Exception: | |||
| raise ValueError("invalid json") | |||
| try: | |||
| payload = payload_type(**data) | |||
| except Exception as e: | |||
| raise ValueError(f"invalid payload: {str(e)}") | |||
| kwargs["payload"] = payload | |||
| return view_func(*args, **kwargs) | |||
| return decorated_view | |||
| if view is None: | |||
| return decorator | |||
| else: | |||
| return decorator(view) | |||
| @@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api.wraps import inner_api_only | |||
| from controllers.inner_api.wraps import enterprise_inner_api_only | |||
| from events.tenant_event import tenant_was_created | |||
| from models.account import Account | |||
| from services.account_service import TenantService | |||
| @@ -12,7 +12,7 @@ from services.account_service import TenantService | |||
| class EnterpriseWorkspace(Resource): | |||
| @setup_required | |||
| @inner_api_only | |||
| @enterprise_inner_api_only | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| @@ -33,7 +33,7 @@ class EnterpriseWorkspace(Resource): | |||
| class EnterpriseWorkspaceNoOwnerEmail(Resource): | |||
| @setup_required | |||
| @inner_api_only | |||
| @enterprise_inner_api_only | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| @@ -10,7 +10,7 @@ from extensions.ext_database import db | |||
| from models.model import EndUser | |||
| def inner_api_only(view): | |||
| def enterprise_inner_api_only(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not dify_config.INNER_API: | |||
| @@ -18,7 +18,7 @@ def inner_api_only(view): | |||
| # get header 'X-Inner-Api-Key' | |||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | |||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: | |||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: | |||
| abort(401) | |||
| return view(*args, **kwargs) | |||
| @@ -26,7 +26,7 @@ def inner_api_only(view): | |||
| return decorated | |||
| def inner_api_user_auth(view): | |||
| def enterprise_inner_api_user_auth(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not dify_config.INNER_API: | |||
| @@ -60,3 +60,19 @@ def inner_api_user_auth(view): | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| def plugin_inner_api_only(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not dify_config.PLUGIN_DAEMON_KEY: | |||
| abort(404) | |||
| # get header 'X-Inner-Api-Key' | |||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | |||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: | |||
| abort(404) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| @@ -1,7 +1,6 @@ | |||
| import json | |||
| import logging | |||
| import uuid | |||
| from datetime import UTC, datetime | |||
| from typing import Optional, Union, cast | |||
| from core.agent.entities import AgentEntity, AgentToolEntity | |||
| @@ -32,19 +31,16 @@ from core.model_runtime.entities import ( | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolParameter, | |||
| ToolRuntimeVariablePool, | |||
| ) | |||
| from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.model import Conversation, Message, MessageAgentThought, MessageFile | |||
| from models.tools import ToolConversationVariables | |||
| logger = logging.getLogger(__name__) | |||
| @@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner): | |||
| queue_manager: AppQueueManager, | |||
| message: Message, | |||
| user_id: str, | |||
| model_instance: ModelInstance, | |||
| memory: Optional[TokenBufferMemory] = None, | |||
| prompt_messages: Optional[list[PromptMessage]] = None, | |||
| variables_pool: Optional[ToolRuntimeVariablePool] = None, | |||
| db_variables: Optional[ToolConversationVariables] = None, | |||
| model_instance: ModelInstance, | |||
| ) -> None: | |||
| self.tenant_id = tenant_id | |||
| self.application_generate_entity = application_generate_entity | |||
| @@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner): | |||
| self.user_id = user_id | |||
| self.memory = memory | |||
| self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) | |||
| self.variables_pool = variables_pool | |||
| self.db_variables_pool = db_variables | |||
| self.model_instance = model_instance | |||
| # init callback | |||
| @@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner): | |||
| agent_tool=tool, | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| ) | |||
| tool_entity.load_variables(self.variables_pool) | |||
| assert tool_entity.entity.description | |||
| message_tool = PromptMessageTool( | |||
| name=tool.tool_name, | |||
| description=tool_entity.description.llm if tool_entity.description else "", | |||
| description=tool_entity.entity.description.llm, | |||
| parameters={ | |||
| "type": "object", | |||
| "properties": {}, | |||
| @@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner): | |||
| }, | |||
| ) | |||
| parameters = tool_entity.get_all_runtime_parameters() | |||
| parameters = tool_entity.get_merged_runtime_parameters() | |||
| for parameter in parameters: | |||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | |||
| continue | |||
| @@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner): | |||
| """ | |||
| convert dataset retriever tool to prompt message tool | |||
| """ | |||
| assert tool.entity.description | |||
| prompt_tool = PromptMessageTool( | |||
| name=tool.identity.name if tool.identity else "unknown", | |||
| description=tool.description.llm if tool.description else "", | |||
| name=tool.entity.identity.name, | |||
| description=tool.entity.description.llm, | |||
| parameters={ | |||
| "type": "object", | |||
| "properties": {}, | |||
| @@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner): | |||
| # save prompt tool | |||
| prompt_messages_tools.append(prompt_tool) | |||
| # save tool entity | |||
| if dataset_tool.identity is not None: | |||
| tool_instances[dataset_tool.identity.name] = dataset_tool | |||
| tool_instances[dataset_tool.entity.identity.name] = dataset_tool | |||
| return tool_instances, prompt_messages_tools | |||
| @@ -320,24 +312,24 @@ class BaseAgentRunner(AppRunner): | |||
| def save_agent_thought( | |||
| self, | |||
| agent_thought: MessageAgentThought, | |||
| tool_name: str, | |||
| tool_input: Union[str, dict], | |||
| thought: str, | |||
| tool_name: str | None, | |||
| tool_input: Union[str, dict, None], | |||
| thought: str | None, | |||
| observation: Union[str, dict, None], | |||
| tool_invoke_meta: Union[str, dict, None], | |||
| answer: str, | |||
| answer: str | None, | |||
| messages_ids: list[str], | |||
| llm_usage: LLMUsage | None = None, | |||
| ): | |||
| """ | |||
| Save agent thought | |||
| """ | |||
| queried_thought = ( | |||
| updated_agent_thought = ( | |||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() | |||
| ) | |||
| if not queried_thought: | |||
| raise ValueError(f"Agent thought {agent_thought.id} not found") | |||
| agent_thought = queried_thought | |||
| if not updated_agent_thought: | |||
| raise ValueError("agent thought not found") | |||
| agent_thought = updated_agent_thought | |||
| if thought: | |||
| agent_thought.thought = thought | |||
| @@ -349,39 +341,39 @@ class BaseAgentRunner(AppRunner): | |||
| if isinstance(tool_input, dict): | |||
| try: | |||
| tool_input = json.dumps(tool_input, ensure_ascii=False) | |||
| except Exception as e: | |||
| except Exception: | |||
| tool_input = json.dumps(tool_input) | |||
| agent_thought.tool_input = tool_input | |||
| updated_agent_thought.tool_input = tool_input | |||
| if observation: | |||
| if isinstance(observation, dict): | |||
| try: | |||
| observation = json.dumps(observation, ensure_ascii=False) | |||
| except Exception as e: | |||
| except Exception: | |||
| observation = json.dumps(observation) | |||
| agent_thought.observation = observation | |||
| updated_agent_thought.observation = observation | |||
| if answer: | |||
| agent_thought.answer = answer | |||
| if messages_ids is not None and len(messages_ids) > 0: | |||
| agent_thought.message_files = json.dumps(messages_ids) | |||
| updated_agent_thought.message_files = json.dumps(messages_ids) | |||
| if llm_usage: | |||
| agent_thought.message_token = llm_usage.prompt_tokens | |||
| agent_thought.message_price_unit = llm_usage.prompt_price_unit | |||
| agent_thought.message_unit_price = llm_usage.prompt_unit_price | |||
| agent_thought.answer_token = llm_usage.completion_tokens | |||
| agent_thought.answer_price_unit = llm_usage.completion_price_unit | |||
| agent_thought.answer_unit_price = llm_usage.completion_unit_price | |||
| agent_thought.tokens = llm_usage.total_tokens | |||
| agent_thought.total_price = llm_usage.total_price | |||
| updated_agent_thought.message_token = llm_usage.prompt_tokens | |||
| updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit | |||
| updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price | |||
| updated_agent_thought.answer_token = llm_usage.completion_tokens | |||
| updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit | |||
| updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price | |||
| updated_agent_thought.tokens = llm_usage.total_tokens | |||
| updated_agent_thought.total_price = llm_usage.total_price | |||
| # check if tool labels is not empty | |||
| labels = agent_thought.tool_labels or {} | |||
| tools = agent_thought.tool.split(";") if agent_thought.tool else [] | |||
| labels = updated_agent_thought.tool_labels or {} | |||
| tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] | |||
| for tool in tools: | |||
| if not tool: | |||
| continue | |||
| @@ -392,39 +384,17 @@ class BaseAgentRunner(AppRunner): | |||
| else: | |||
| labels[tool] = {"en_US": tool, "zh_Hans": tool} | |||
| agent_thought.tool_labels_str = json.dumps(labels) | |||
| updated_agent_thought.tool_labels_str = json.dumps(labels) | |||
| if tool_invoke_meta is not None: | |||
| if isinstance(tool_invoke_meta, dict): | |||
| try: | |||
| tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) | |||
| except Exception as e: | |||
| except Exception: | |||
| tool_invoke_meta = json.dumps(tool_invoke_meta) | |||
| agent_thought.tool_meta_str = tool_invoke_meta | |||
| db.session.commit() | |||
| db.session.close() | |||
| def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): | |||
| """ | |||
| convert tool variables to db variables | |||
| """ | |||
| queried_variables = ( | |||
| db.session.query(ToolConversationVariables) | |||
| .filter( | |||
| ToolConversationVariables.conversation_id == self.message.conversation_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not queried_variables: | |||
| return | |||
| updated_agent_thought.tool_meta_str = tool_invoke_meta | |||
| db_variables = queried_variables | |||
| db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) | |||
| db.session.commit() | |||
| db.session.close() | |||
| @@ -464,11 +434,11 @@ class BaseAgentRunner(AppRunner): | |||
| tool_call_response: list[ToolPromptMessage] = [] | |||
| try: | |||
| tool_inputs = json.loads(agent_thought.tool_input) | |||
| except Exception as e: | |||
| except Exception: | |||
| tool_inputs = {tool: {} for tool in tools} | |||
| try: | |||
| tool_responses = json.loads(agent_thought.observation) | |||
| except Exception as e: | |||
| except Exception: | |||
| tool_responses = dict.fromkeys(tools, agent_thought.observation) | |||
| for tool in tools: | |||
| @@ -515,7 +485,11 @@ class BaseAgentRunner(AppRunner): | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| if not files: | |||
| return UserPromptMessage(content=message.query) | |||
| file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) | |||
| if message.app_model_config: | |||
| file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) | |||
| else: | |||
| file_extra_config = None | |||
| if not file_extra_config: | |||
| return UserPromptMessage(content=message.query) | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator, Mapping | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional | |||
| from core.agent.base_agent_runner import BaseAgentRunner | |||
| @@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import ( | |||
| ) | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.entities.tool_entities import ToolInvokeMeta | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.tool_engine import ToolEngine | |||
| from models.model import Message | |||
| @@ -27,11 +27,11 @@ from models.model import Message | |||
| class CotAgentRunner(BaseAgentRunner, ABC): | |||
| _is_first_iteration = True | |||
| _ignore_observation_providers = ["wenxin"] | |||
| _historic_prompt_messages: list[PromptMessage] | None = None | |||
| _agent_scratchpad: list[AgentScratchpadUnit] | None = None | |||
| _instruction: str = "" # FIXME this must be str for now | |||
| _query: str | None = None | |||
| _prompt_messages_tools: list[PromptMessageTool] = [] | |||
| _historic_prompt_messages: list[PromptMessage] | |||
| _agent_scratchpad: list[AgentScratchpadUnit] | |||
| _instruction: str | |||
| _query: str | |||
| _prompt_messages_tools: Sequence[PromptMessageTool] | |||
| def run( | |||
| self, | |||
| @@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| """ | |||
| Run Cot agent application | |||
| """ | |||
| app_generate_entity = self.application_generate_entity | |||
| self._repack_app_generate_entity(app_generate_entity) | |||
| self._init_react_state(query) | |||
| @@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| app_generate_entity.model_conf.stop.append("Observation") | |||
| app_config = self.app_config | |||
| assert app_config.agent | |||
| # init instruction | |||
| inputs = inputs or {} | |||
| instruction = app_config.prompt_template.simple_prompt_template | |||
| self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) | |||
| instruction = app_config.prompt_template.simple_prompt_template or "" | |||
| self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) | |||
| iteration_step = 1 | |||
| max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 | |||
| # convert tools into ModelRuntime Tool format | |||
| tool_instances, self._prompt_messages_tools = self._init_prompt_tools() | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| self._prompt_messages_tools = prompt_messages_tools | |||
| function_call_state = True | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| @@ -116,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| callbacks=[], | |||
| ) | |||
| if not isinstance(chunks, Generator): | |||
| raise ValueError("Expected streaming response from LLM") | |||
| # check llm result | |||
| if not chunks: | |||
| raise ValueError("failed to invoke llm") | |||
| usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| usage_dict: dict[str, Optional[LLMUsage]] = {} | |||
| react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) | |||
| scratchpad = AgentScratchpadUnit( | |||
| agent_response="", | |||
| @@ -143,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| if isinstance(chunk, AgentScratchpadUnit.Action): | |||
| action = chunk | |||
| # detect action | |||
| if scratchpad.agent_response is not None: | |||
| scratchpad.agent_response += json.dumps(chunk.model_dump()) | |||
| assert scratchpad.agent_response is not None | |||
| scratchpad.agent_response += json.dumps(chunk.model_dump()) | |||
| scratchpad.action_str = json.dumps(chunk.model_dump()) | |||
| scratchpad.action = action | |||
| else: | |||
| if scratchpad.agent_response is not None: | |||
| scratchpad.agent_response += chunk | |||
| if scratchpad.thought is not None: | |||
| scratchpad.thought += chunk | |||
| assert scratchpad.agent_response is not None | |||
| scratchpad.agent_response += chunk | |||
| assert scratchpad.thought is not None | |||
| scratchpad.thought += chunk | |||
| yield LLMResultChunk( | |||
| model=self.model_config.model, | |||
| prompt_messages=prompt_messages, | |||
| system_fingerprint="", | |||
| delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), | |||
| ) | |||
| if scratchpad.thought is not None: | |||
| scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" | |||
| if self._agent_scratchpad is not None: | |||
| self._agent_scratchpad.append(scratchpad) | |||
| assert scratchpad.thought is not None | |||
| scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" | |||
| self._agent_scratchpad.append(scratchpad) | |||
| # get llm usage | |||
| if "usage" in usage_dict: | |||
| @@ -256,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| answer=final_answer, | |||
| messages_ids=[], | |||
| ) | |||
| if self.variables_pool is not None and self.db_variables_pool is not None: | |||
| self.update_db_variables(self.variables_pool, self.db_variables_pool) | |||
| # publish end event | |||
| self.queue_manager.publish( | |||
| QueueMessageEndEvent( | |||
| @@ -275,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| def _handle_invoke_action( | |||
| self, | |||
| action: AgentScratchpadUnit.Action, | |||
| tool_instances: dict[str, Tool], | |||
| tool_instances: Mapping[str, Tool], | |||
| message_file_ids: list[str], | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> tuple[str, ToolInvokeMeta]: | |||
| @@ -315,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| ) | |||
| # publish files | |||
| for message_file_id, save_as in message_files: | |||
| if save_as is not None and self.variables_pool: | |||
| # FIXME the save_as type is confusing, it should be a string or not | |||
| self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) | |||
| for message_file_id in message_files: | |||
| # publish message file | |||
| self.queue_manager.publish( | |||
| QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | |||
| @@ -342,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| for key, value in inputs.items(): | |||
| try: | |||
| instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) | |||
| except Exception as e: | |||
| except Exception: | |||
| continue | |||
| return instruction | |||
| @@ -379,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| return message | |||
| def _organize_historic_prompt_messages( | |||
| self, current_session_messages: Optional[list[PromptMessage]] = None | |||
| self, current_session_messages: list[PromptMessage] | None = None | |||
| ) -> list[PromptMessage]: | |||
| """ | |||
| organize historic prompt messages | |||
| @@ -391,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| for message in self.history_prompt_messages: | |||
| if isinstance(message, AssistantPromptMessage): | |||
| if not current_scratchpad: | |||
| if not isinstance(message.content, str | None): | |||
| raise NotImplementedError("expected str type") | |||
| assert isinstance(message.content, str) | |||
| current_scratchpad = AgentScratchpadUnit( | |||
| agent_response=message.content, | |||
| thought=message.content or "I am thinking about how to help you", | |||
| @@ -411,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| except: | |||
| pass | |||
| elif isinstance(message, ToolPromptMessage): | |||
| if not current_scratchpad: | |||
| continue | |||
| if isinstance(message.content, str): | |||
| if current_scratchpad: | |||
| assert isinstance(message.content, str) | |||
| current_scratchpad.observation = message.content | |||
| else: | |||
| raise NotImplementedError("expected str type") | |||
| @@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner): | |||
| """ | |||
| Organize system prompt | |||
| """ | |||
| if not self.app_config.agent: | |||
| raise ValueError("Agent configuration is not set") | |||
| assert self.app_config.agent | |||
| assert self.app_config.agent.prompt | |||
| prompt_entity = self.app_config.agent.prompt | |||
| if not prompt_entity: | |||
| @@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner): | |||
| assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str | |||
| for unit in agent_scratchpad: | |||
| if unit.is_final(): | |||
| assert isinstance(assistant_message.content, str) | |||
| assistant_message.content += f"Final Answer: {unit.agent_response}" | |||
| else: | |||
| assert isinstance(assistant_message.content, str) | |||
| assistant_message.content += f"Thought: {unit.thought}\n\n" | |||
| if unit.action_str: | |||
| assistant_message.content += f"Action: {unit.action_str}\n\n" | |||
| @@ -1,18 +1,21 @@ | |||
| from enum import Enum | |||
| from typing import Any, Literal, Optional, Union | |||
| from enum import StrEnum | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||
| class AgentToolEntity(BaseModel): | |||
| """ | |||
| Agent Tool Entity. | |||
| """ | |||
| provider_type: Literal["builtin", "api", "workflow"] | |||
| provider_type: ToolProviderType | |||
| provider_id: str | |||
| tool_name: str | |||
| tool_parameters: dict[str, Any] = {} | |||
| plugin_unique_identifier: str | None = None | |||
| class AgentPromptEntity(BaseModel): | |||
| @@ -66,7 +69,7 @@ class AgentEntity(BaseModel): | |||
| Agent Entity. | |||
| """ | |||
| class Strategy(Enum): | |||
| class Strategy(StrEnum): | |||
| """ | |||
| Agent Strategy. | |||
| """ | |||
| @@ -78,5 +81,13 @@ class AgentEntity(BaseModel): | |||
| model: str | |||
| strategy: Strategy | |||
| prompt: Optional[AgentPromptEntity] = None | |||
| tools: list[AgentToolEntity] | None = None | |||
| tools: Optional[list[AgentToolEntity]] = None | |||
| max_iteration: int = 5 | |||
| class AgentInvokeMessage(ToolInvokeMessage): | |||
| """ | |||
| Agent Invoke Message. | |||
| """ | |||
| pass | |||
| @@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| # convert tools into ModelRuntime Tool format | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| assert app_config.agent | |||
| iteration_step = 1 | |||
| max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 | |||
| # continue to run until there is not any tool call | |||
| function_call_state = True | |||
| llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| final_answer = "" | |||
| # get tracing instance | |||
| trace_manager = app_generate_entity.trace_manager | |||
| def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): | |||
| def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): | |||
| if not final_llm_usage_dict["usage"]: | |||
| final_llm_usage_dict["usage"] = usage | |||
| else: | |||
| @@ -107,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| current_llm_usage = None | |||
| if self.stream_tool_call and isinstance(chunks, Generator): | |||
| if isinstance(chunks, Generator): | |||
| is_first_chunk = True | |||
| for chunk in chunks: | |||
| if is_first_chunk: | |||
| @@ -124,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| tool_call_inputs = json.dumps( | |||
| {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | |||
| ) | |||
| except json.JSONDecodeError as e: | |||
| except json.JSONDecodeError: | |||
| # ensure ascii to avoid encoding error | |||
| tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | |||
| @@ -140,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| current_llm_usage = chunk.delta.usage | |||
| yield chunk | |||
| elif not self.stream_tool_call and isinstance(chunks, LLMResult): | |||
| else: | |||
| result = chunks | |||
| # check if there is any tool call | |||
| if self.check_blocking_tool_calls(result): | |||
| @@ -151,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| tool_call_inputs = json.dumps( | |||
| {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | |||
| ) | |||
| except json.JSONDecodeError as e: | |||
| except json.JSONDecodeError: | |||
| # ensure ascii to avoid encoding error | |||
| tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | |||
| @@ -183,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| usage=result.usage, | |||
| ), | |||
| ) | |||
| else: | |||
| raise RuntimeError(f"invalid chunks type: {type(chunks)}") | |||
| assistant_message = AssistantPromptMessage(content="", tool_calls=[]) | |||
| if tool_calls: | |||
| @@ -243,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| agent_tool_callback=self.agent_callback, | |||
| trace_manager=trace_manager, | |||
| app_id=self.application_generate_entity.app_config.app_id, | |||
| message_id=self.message.id, | |||
| conversation_id=self.conversation.id, | |||
| ) | |||
| # publish files | |||
| for message_file_id, save_as in message_files: | |||
| if save_as: | |||
| if self.variables_pool: | |||
| self.variables_pool.set_file( | |||
| tool_name=tool_call_name, value=message_file_id, name=save_as | |||
| ) | |||
| for message_file_id in message_files: | |||
| # publish message file | |||
| self.queue_manager.publish( | |||
| QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | |||
| @@ -303,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| iteration_step += 1 | |||
| if self.variables_pool and self.db_variables_pool: | |||
| self.update_db_variables(self.variables_pool, self.db_variables_pool) | |||
| # publish end event | |||
| self.queue_manager.publish( | |||
| QueueMessageEndEvent( | |||
| @@ -335,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| return True | |||
| return False | |||
| def extract_tool_calls( | |||
| self, llm_result_chunk: LLMResultChunk | |||
| ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||
| def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: | |||
| """ | |||
| Extract tool calls from llm result chunk | |||
| @@ -360,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| return tool_calls | |||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: | |||
| """ | |||
| Extract blocking tool calls from llm result | |||
| @@ -383,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| return tool_calls | |||
| def _init_system_message( | |||
| self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None | |||
| ) -> list[PromptMessage]: | |||
| def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: | |||
| """ | |||
| Initialize system message | |||
| """ | |||
| @@ -0,0 +1,89 @@ | |||
| import enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator | |||
| from core.entities.parameter_entities import CommonParameterType | |||
| from core.plugin.entities.parameters import ( | |||
| PluginParameter, | |||
| as_normal_type, | |||
| cast_parameter_value, | |||
| init_frontend_parameter, | |||
| ) | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolIdentity, | |||
| ToolProviderIdentity, | |||
| ) | |||
| class AgentStrategyProviderIdentity(ToolProviderIdentity): | |||
| """ | |||
| Inherits from ToolProviderIdentity, without any additional fields. | |||
| """ | |||
| pass | |||
| class AgentStrategyParameter(PluginParameter): | |||
| class AgentStrategyParameterType(enum.StrEnum): | |||
| """ | |||
| Keep all the types from PluginParameterType | |||
| """ | |||
| STRING = CommonParameterType.STRING.value | |||
| NUMBER = CommonParameterType.NUMBER.value | |||
| BOOLEAN = CommonParameterType.BOOLEAN.value | |||
| SELECT = CommonParameterType.SELECT.value | |||
| SECRET_INPUT = CommonParameterType.SECRET_INPUT.value | |||
| FILE = CommonParameterType.FILE.value | |||
| FILES = CommonParameterType.FILES.value | |||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||
| TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value | |||
| # deprecated, should not use. | |||
| SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value | |||
| def as_normal_type(self): | |||
| return as_normal_type(self) | |||
| def cast_value(self, value: Any): | |||
| return cast_parameter_value(self, value) | |||
| type: AgentStrategyParameterType = Field(..., description="The type of the parameter") | |||
| def init_frontend_parameter(self, value: Any): | |||
| return init_frontend_parameter(self, self.type, value) | |||
| class AgentStrategyProviderEntity(BaseModel): | |||
| identity: AgentStrategyProviderIdentity | |||
| plugin_id: Optional[str] = Field(None, description="The id of the plugin") | |||
| class AgentStrategyIdentity(ToolIdentity): | |||
| """ | |||
| Inherits from ToolIdentity, without any additional fields. | |||
| """ | |||
| pass | |||
| class AgentStrategyEntity(BaseModel): | |||
| identity: AgentStrategyIdentity | |||
| parameters: list[AgentStrategyParameter] = Field(default_factory=list) | |||
| description: I18nObject = Field(..., description="The description of the agent strategy") | |||
| output_schema: Optional[dict] = None | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @field_validator("parameters", mode="before") | |||
| @classmethod | |||
| def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]: | |||
| return v or [] | |||
| class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity): | |||
| strategies: list[AgentStrategyEntity] = Field(default_factory=list) | |||
| @@ -0,0 +1,42 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator, Sequence | |||
| from typing import Any, Optional | |||
| from core.agent.entities import AgentInvokeMessage | |||
| from core.agent.plugin_entities import AgentStrategyParameter | |||
| class BaseAgentStrategy(ABC): | |||
| """ | |||
| Agent Strategy | |||
| """ | |||
| def invoke( | |||
| self, | |||
| params: dict[str, Any], | |||
| user_id: str, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||
| """ | |||
| Invoke the agent strategy. | |||
| """ | |||
| yield from self._invoke(params, user_id, conversation_id, app_id, message_id) | |||
| def get_parameters(self) -> Sequence[AgentStrategyParameter]: | |||
| """ | |||
| Get the parameters for the agent strategy. | |||
| """ | |||
| return [] | |||
| @abstractmethod | |||
| def _invoke( | |||
| self, | |||
| params: dict[str, Any], | |||
| user_id: str, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||
| pass | |||
| @@ -0,0 +1,59 @@ | |||
| from collections.abc import Generator, Sequence | |||
| from typing import Any, Optional | |||
| from core.agent.entities import AgentInvokeMessage | |||
| from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter | |||
| from core.agent.strategy.base import BaseAgentStrategy | |||
| from core.plugin.manager.agent import PluginAgentManager | |||
| from core.plugin.utils.converter import convert_parameters_to_plugin_format | |||
| class PluginAgentStrategy(BaseAgentStrategy): | |||
| """ | |||
| Agent Strategy | |||
| """ | |||
| tenant_id: str | |||
| declaration: AgentStrategyEntity | |||
| def __init__(self, tenant_id: str, declaration: AgentStrategyEntity): | |||
| self.tenant_id = tenant_id | |||
| self.declaration = declaration | |||
| def get_parameters(self) -> Sequence[AgentStrategyParameter]: | |||
| return self.declaration.parameters | |||
| def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| Initialize the parameters for the agent strategy. | |||
| """ | |||
| for parameter in self.declaration.parameters: | |||
| params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name)) | |||
| return params | |||
| def _invoke( | |||
| self, | |||
| params: dict[str, Any], | |||
| user_id: str, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||
| """ | |||
| Invoke the agent strategy. | |||
| """ | |||
| manager = PluginAgentManager() | |||
| initialized_params = self.initialize_parameters(params) | |||
| params = convert_parameters_to_plugin_format(initialized_params) | |||
| yield from manager.invoke( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| agent_provider=self.declaration.identity.provider, | |||
| agent_strategy=self.declaration.identity.name, | |||
| agent_params=params, | |||
| conversation_id=conversation_id, | |||
| app_id=app_id, | |||
| message_id=message_id, | |||
| ) | |||
| @@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.provider_manager import ProviderManager | |||
| @@ -63,14 +64,14 @@ class ModelConfigConverter: | |||
| stop = completion_params["stop"] | |||
| del completion_params["stop"] | |||
| model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) | |||
| # get model mode | |||
| model_mode = model_config.mode | |||
| if not model_mode: | |||
| mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) | |||
| model_mode = mode_enum.value | |||
| model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) | |||
| model_mode = LLMMode.CHAT.value | |||
| if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): | |||
| model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value | |||
| if not model_schema: | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| @@ -2,8 +2,9 @@ from collections.abc import Mapping | |||
| from typing import Any | |||
| from core.app.app_config.entities import ModelConfigEntity | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.provider_manager import ProviderManager | |||
| @@ -53,9 +54,18 @@ class ModelConfigManager: | |||
| raise ValueError("model must be of object type") | |||
| # model.provider | |||
| model_provider_factory = ModelProviderFactory(tenant_id) | |||
| provider_entities = model_provider_factory.get_providers() | |||
| model_provider_names = [provider.provider for provider in provider_entities] | |||
| if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: | |||
| if "provider" not in config["model"]: | |||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | |||
| if "/" not in config["model"]["provider"]: | |||
| config["model"]["provider"] = ( | |||
| f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" | |||
| ) | |||
| if config["model"]["provider"] not in model_provider_names: | |||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | |||
| # model.name | |||
| @@ -45,8 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| ) -> Generator[str, None, None]: ... | |||
| streaming: Literal[False], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -54,10 +54,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| args: Mapping, | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| ) -> Mapping[str, Any]: ... | |||
| streaming: Literal[True], | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -65,20 +65,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| args: Mapping, | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... | |||
| streaming: bool, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... | |||
| def generate( | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| args: Mapping, | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -154,6 +154,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| workflow_run_id=workflow_run_id, | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| return self._generate( | |||
| workflow=workflow, | |||
| @@ -165,8 +167,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| def single_iteration_generate( | |||
| self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: Mapping, | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -203,6 +211,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| ), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| return self._generate( | |||
| workflow=workflow, | |||
| @@ -222,7 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| conversation: Optional[Conversation] = None, | |||
| stream: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -56,7 +56,7 @@ def _process_future( | |||
| class AppGeneratorTTSPublisher: | |||
| def __init__(self, tenant_id: str, voice: str): | |||
| def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None): | |||
| self.logger = logging.getLogger(__name__) | |||
| self.tenant_id = tenant_id | |||
| self.msg_text = "" | |||
| @@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher: | |||
| self.model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=self.tenant_id, model_type=ModelType.TTS | |||
| ) | |||
| self.voices = self.model_instance.get_tts_voices() | |||
| self.voices = self.model_instance.get_tts_voices(language=language) | |||
| values = [voice.get("value") for voice in self.voices] | |||
| self.voice = voice | |||
| if not voice or voice not in values: | |||
| @@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import Any, cast | |||
| @@ -58,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[str, Any, None]: | |||
| ) -> Generator[dict | str, Any, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| @@ -84,12 +83,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[str, Any, None]: | |||
| ) -> Generator[dict | str, Any, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| @@ -123,4 +122,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAdvancedChatMessageEndEvent, | |||
| QueueAgentLogEvent, | |||
| QueueAnnotationReplyEvent, | |||
| QueueErrorEvent, | |||
| QueueIterationCompletedEvent, | |||
| @@ -219,7 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| and features_dict["text_to_speech"].get("enabled") | |||
| and features_dict["text_to_speech"].get("autoPlay") == "enabled" | |||
| ): | |||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) | |||
| tts_publisher = AppGeneratorTTSPublisher( | |||
| tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") | |||
| ) | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| @@ -247,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| else: | |||
| start_listener_time = time.time() | |||
| yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) | |||
| except Exception as e: | |||
| except Exception: | |||
| logger.exception(f"Failed to listen audio message, task_id: {task_id}") | |||
| break | |||
| if tts_publisher: | |||
| @@ -640,6 +643,10 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| session.commit() | |||
| yield self._message_end_to_stream_response() | |||
| elif isinstance(event, QueueAgentLogEvent): | |||
| yield self._workflow_cycle_manager._handle_agent_log( | |||
| task_id=self._application_generate_entity.task_id, event=event | |||
| ) | |||
| else: | |||
| continue | |||
| @@ -1,3 +1,4 @@ | |||
| import contextvars | |||
| import logging | |||
| import threading | |||
| import uuid | |||
| @@ -37,8 +38,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| ) -> Generator[str, None, None]: ... | |||
| streaming: Literal[False], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -48,8 +49,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| ) -> Mapping[str, Any]: ... | |||
| streaming: Literal[True], | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -60,7 +61,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| @@ -70,7 +71,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -180,6 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "context": contextvars.copy_context(), | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "conversation_id": conversation.id, | |||
| @@ -204,6 +206,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| def _generate_worker( | |||
| self, | |||
| flask_app: Flask, | |||
| context: contextvars.Context, | |||
| application_generate_entity: AgentChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| @@ -218,6 +221,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| :param message_id: message ID | |||
| :return: | |||
| """ | |||
| for var, val in context.items(): | |||
| var.set(val) | |||
| with flask_app.app_context(): | |||
| try: | |||
| # get conversation and message | |||
| @@ -8,18 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner | |||
| from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity | |||
| from core.app.entities.queue_entities import QueueAnnotationReplyEvent | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.moderation.base import ModerationError | |||
| from core.tools.entities.tool_entities import ToolRuntimeVariablePool | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, Message, MessageAgentThought | |||
| from models.tools import ToolConversationVariables | |||
| from models.model import App, Conversation, Message | |||
| logger = logging.getLogger(__name__) | |||
| @@ -64,8 +62,8 @@ class AgentChatAppRunner(AppRunner): | |||
| app_record=app_record, | |||
| model_config=application_generate_entity.model_conf, | |||
| prompt_template_entity=app_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| inputs=dict(inputs), | |||
| files=list(files), | |||
| query=query, | |||
| ) | |||
| @@ -86,8 +84,8 @@ class AgentChatAppRunner(AppRunner): | |||
| app_record=app_record, | |||
| model_config=application_generate_entity.model_conf, | |||
| prompt_template_entity=app_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| inputs=dict(inputs), | |||
| files=list(files), | |||
| query=query, | |||
| memory=memory, | |||
| ) | |||
| @@ -99,8 +97,8 @@ class AgentChatAppRunner(AppRunner): | |||
| app_id=app_record.id, | |||
| tenant_id=app_config.tenant_id, | |||
| app_generate_entity=application_generate_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| inputs=dict(inputs), | |||
| query=query or "", | |||
| message_id=message.id, | |||
| ) | |||
| except ModerationError as e: | |||
| @@ -156,9 +154,9 @@ class AgentChatAppRunner(AppRunner): | |||
| app_record=app_record, | |||
| model_config=application_generate_entity.model_conf, | |||
| prompt_template_entity=app_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| inputs=dict(inputs), | |||
| files=list(files), | |||
| query=query or "", | |||
| memory=memory, | |||
| ) | |||
| @@ -173,16 +171,7 @@ class AgentChatAppRunner(AppRunner): | |||
| return | |||
| agent_entity = app_config.agent | |||
| if not agent_entity: | |||
| raise ValueError("Agent entity not found") | |||
| # load tool variables | |||
| tool_conversation_variables = self._load_tool_variables( | |||
| conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id | |||
| ) | |||
| # convert db variables to tool variables | |||
| tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) | |||
| assert agent_entity is not None | |||
| # init model instance | |||
| model_instance = ModelInstance( | |||
| @@ -193,9 +182,9 @@ class AgentChatAppRunner(AppRunner): | |||
| app_record=app_record, | |||
| model_config=application_generate_entity.model_conf, | |||
| prompt_template_entity=app_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| inputs=dict(inputs), | |||
| files=list(files), | |||
| query=query or "", | |||
| memory=memory, | |||
| ) | |||
| @@ -243,8 +232,6 @@ class AgentChatAppRunner(AppRunner): | |||
| user_id=application_generate_entity.user_id, | |||
| memory=memory, | |||
| prompt_messages=prompt_message, | |||
| variables_pool=tool_variables, | |||
| db_variables=tool_conversation_variables, | |||
| model_instance=model_instance, | |||
| ) | |||
| @@ -261,73 +248,3 @@ class AgentChatAppRunner(AppRunner): | |||
| stream=application_generate_entity.stream, | |||
| agent=True, | |||
| ) | |||
| def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: | |||
| """ | |||
| load tool variables from database | |||
| """ | |||
| tool_variables: ToolConversationVariables | None = ( | |||
| db.session.query(ToolConversationVariables) | |||
| .filter( | |||
| ToolConversationVariables.conversation_id == conversation_id, | |||
| ToolConversationVariables.tenant_id == tenant_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if tool_variables: | |||
| # save tool variables to session, so that we can update it later | |||
| db.session.add(tool_variables) | |||
| else: | |||
| # create new tool variables | |||
| tool_variables = ToolConversationVariables( | |||
| conversation_id=conversation_id, | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| variables_str="[]", | |||
| ) | |||
| db.session.add(tool_variables) | |||
| db.session.commit() | |||
| return tool_variables | |||
| def _convert_db_variables_to_tool_variables( | |||
| self, db_variables: ToolConversationVariables | |||
| ) -> ToolRuntimeVariablePool: | |||
| """ | |||
| convert db variables to tool variables | |||
| """ | |||
| return ToolRuntimeVariablePool( | |||
| **{ | |||
| "conversation_id": db_variables.conversation_id, | |||
| "user_id": db_variables.user_id, | |||
| "tenant_id": db_variables.tenant_id, | |||
| "pool": db_variables.variables, | |||
| } | |||
| ) | |||
| def _get_usage_of_all_agent_thoughts( | |||
| self, model_config: ModelConfigWithCredentialsEntity, message: Message | |||
| ) -> LLMUsage: | |||
| """ | |||
| Get usage of all agent thoughts | |||
| :param model_config: model config | |||
| :param message: message | |||
| :return: | |||
| """ | |||
| agent_thoughts = ( | |||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() | |||
| ) | |||
| all_message_tokens = 0 | |||
| all_answer_tokens = 0 | |||
| for agent_thought in agent_thoughts: | |||
| all_message_tokens += agent_thought.message_tokens | |||
| all_answer_tokens += agent_thought.answer_tokens | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| return model_type_instance._calc_response_usage( | |||
| model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens | |||
| ) | |||
| @@ -1,9 +1,9 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ChatbotAppBlockingResponse, | |||
| ChatbotAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -51,10 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| return response | |||
| @classmethod | |||
| def convert_stream_full_response( # type: ignore[override] | |||
| cls, | |||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], | |||
| ) -> Generator[str, None, None]: | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| @@ -80,13 +79,12 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( # type: ignore[override] | |||
| cls, | |||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], | |||
| ) -> Generator[str, None, None]: | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| @@ -118,4 +116,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @@ -14,21 +14,15 @@ class AppGenerateResponseConverter(ABC): | |||
| @classmethod | |||
| def convert( | |||
| cls, | |||
| response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], | |||
| invoke_from: InvokeFrom, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||
| cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||
| if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: | |||
| if isinstance(response, AppBlockingResponse): | |||
| return cls.convert_blocking_full_response(response) | |||
| else: | |||
| def _generate_full_response() -> Generator[str, Any, None]: | |||
| for chunk in cls.convert_stream_full_response(response): | |||
| if chunk == "ping": | |||
| yield f"event: {chunk}\n\n" | |||
| else: | |||
| yield f"data: {chunk}\n\n" | |||
| def _generate_full_response() -> Generator[dict | str, Any, None]: | |||
| yield from cls.convert_stream_full_response(response) | |||
| return _generate_full_response() | |||
| else: | |||
| @@ -36,12 +30,8 @@ class AppGenerateResponseConverter(ABC): | |||
| return cls.convert_blocking_simple_response(response) | |||
| else: | |||
| def _generate_simple_response() -> Generator[str, Any, None]: | |||
| for chunk in cls.convert_stream_simple_response(response): | |||
| if chunk == "ping": | |||
| yield f"event: {chunk}\n\n" | |||
| else: | |||
| yield f"data: {chunk}\n\n" | |||
| def _generate_simple_response() -> Generator[dict | str, Any, None]: | |||
| yield from cls.convert_stream_simple_response(response) | |||
| return _generate_simple_response() | |||
| @@ -59,14 +49,14 @@ class AppGenerateResponseConverter(ABC): | |||
| @abstractmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[str, None, None]: | |||
| ) -> Generator[dict | str, None, None]: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| @abstractmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[str, None, None]: | |||
| ) -> Generator[dict | str, None, None]: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| @@ -1,5 +1,6 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional | |||
| import json | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from core.app.app_config.entities import VariableEntityType | |||
| from core.file import File, FileUploadConfig | |||
| @@ -138,3 +139,21 @@ class BaseAppGenerator: | |||
| if isinstance(value, str): | |||
| return value.replace("\x00", "") | |||
| return value | |||
| @classmethod | |||
| def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]): | |||
| """ | |||
| Convert messages into event stream | |||
| """ | |||
| if isinstance(generator, dict): | |||
| return generator | |||
| else: | |||
| def gen(): | |||
| for message in generator: | |||
| if isinstance(message, (Mapping, dict)): | |||
| yield f"data: {json.dumps(message)}\n\n" | |||
| else: | |||
| yield f"event: {message}\n\n" | |||
| return gen() | |||
| @@ -2,7 +2,7 @@ import queue | |||
| import time | |||
| from abc import abstractmethod | |||
| from enum import Enum | |||
| from typing import Any | |||
| from typing import Any, Optional | |||
| from sqlalchemy.orm import DeclarativeMeta | |||
| @@ -115,7 +115,7 @@ class AppQueueManager: | |||
| Set task stop flag | |||
| :return: | |||
| """ | |||
| result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) | |||
| result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id)) | |||
| if result is None: | |||
| return | |||
| @@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| ) -> Generator[str, None, None]: ... | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| @@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -1,9 +1,9 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ChatbotAppBlockingResponse, | |||
| ChatbotAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -52,9 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, | |||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| @@ -80,13 +79,12 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, | |||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| @@ -118,4 +116,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| ) -> Generator[str, None, None]: ... | |||
| ) -> Generator[str | Mapping[str, Any], None, None]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||
| streaming: bool = False, | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ... | |||
| def generate( | |||
| self, | |||
| @@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| user: Union[Account, EndUser], | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: | |||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -1,9 +1,9 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| CompletionAppBlockingResponse, | |||
| CompletionAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -51,9 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, | |||
| stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| @@ -78,13 +77,12 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, | |||
| stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| @@ -115,4 +113,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| *, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Account | EndUser, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Generator[str, None, None]: ... | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| def generate( | |||
| @@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| *, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Account | EndUser, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| @@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| *, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Account | EndUser, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||
| streaming: bool, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| *, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Account | EndUser, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ): | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: | |||
| files: Sequence[Mapping[str, Any]] = args.get("files") or [] | |||
| # parse files | |||
| @@ -124,7 +124,10 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| trace_manager=trace_manager, | |||
| workflow_run_id=workflow_run_id, | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| @@ -146,7 +149,18 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate App response. | |||
| :param app_model: App | |||
| :param workflow: Workflow | |||
| :param user: account or end user | |||
| :param application_generate_entity: application generate entity | |||
| :param invoke_from: invoke from source | |||
| :param stream: is stream | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| # init queue manager | |||
| queue_manager = WorkflowAppQueueManager( | |||
| task_id=application_generate_entity.task_id, | |||
| @@ -185,10 +199,10 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account, | |||
| user: Account | EndUser, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -224,6 +238,8 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| @@ -1,9 +1,9 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ErrorStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeStartStreamResponse, | |||
| @@ -36,9 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, | |||
| stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| @@ -62,13 +61,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, | |||
| stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] | |||
| ) -> Generator[str, None, None]: | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| @@ -94,4 +92,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield json.dumps(response_chunk) | |||
| yield response_chunk | |||
| @@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAgentLogEvent, | |||
| QueueErrorEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| @@ -190,7 +191,9 @@ class WorkflowAppGenerateTaskPipeline: | |||
| and features_dict["text_to_speech"].get("enabled") | |||
| and features_dict["text_to_speech"].get("autoPlay") == "enabled" | |||
| ): | |||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) | |||
| tts_publisher = AppGeneratorTTSPublisher( | |||
| tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") | |||
| ) | |||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | |||
| while True: | |||
| @@ -527,6 +530,10 @@ class WorkflowAppGenerateTaskPipeline: | |||
| yield self._text_chunk_to_stream_response( | |||
| delta_text, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueAgentLogEvent): | |||
| yield self._workflow_cycle_manager._handle_agent_log( | |||
| task_id=self._application_generate_entity.task_id, event=event | |||
| ) | |||
| else: | |||
| continue | |||
| @@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueAgentLogEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| @@ -27,6 +28,7 @@ from core.app.entities.queue_entities import ( | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| AgentLogEvent, | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| @@ -239,6 +241,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| agent_strategy=event.agent_strategy, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| @@ -373,6 +376,19 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id | |||
| ) | |||
| ) | |||
| elif isinstance(event, AgentLogEvent): | |||
| self._publish_event( | |||
| QueueAgentLogEvent( | |||
| id=event.id, | |||
| label=event.label, | |||
| node_execution_id=event.node_execution_id, | |||
| parent_id=event.parent_id, | |||
| error=event.error, | |||
| status=event.status, | |||
| data=event.data, | |||
| metadata=event.metadata, | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunStartedEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunStartedEvent( | |||
| @@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): | |||
| """ | |||
| node_id: str | |||
| inputs: dict | |||
| inputs: Mapping | |||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | |||
| @@ -6,7 +6,7 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base import BaseNodeData | |||
| @@ -41,6 +41,7 @@ class QueueEvent(StrEnum): | |||
| PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" | |||
| PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" | |||
| PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" | |||
| AGENT_LOG = "agent_log" | |||
| ERROR = "error" | |||
| PING = "ping" | |||
| STOP = "stop" | |||
| @@ -280,6 +281,7 @@ class QueueNodeStartedEvent(AppQueueEvent): | |||
| start_at: datetime | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteratoin run in parallel mode run id""" | |||
| agent_strategy: Optional[AgentNodeStrategyInit] = None | |||
| class QueueNodeSucceededEvent(AppQueueEvent): | |||
| @@ -315,6 +317,22 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| iteration_duration_map: Optional[dict[str, float]] = None | |||
| class QueueAgentLogEvent(AppQueueEvent): | |||
| """ | |||
| QueueAgentLogEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.AGENT_LOG | |||
| id: str | |||
| label: str | |||
| node_execution_id: str | |||
| parent_id: str | None | |||
| error: str | None | |||
| status: str | |||
| data: Mapping[str, Any] | |||
| metadata: Optional[Mapping[str, Any]] = None | |||
| class QueueNodeRetryEvent(QueueNodeStartedEvent): | |||
| """QueueNodeRetryEvent entity""" | |||
| @@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -60,6 +61,7 @@ class StreamEvent(Enum): | |||
| ITERATION_COMPLETED = "iteration_completed" | |||
| TEXT_CHUNK = "text_chunk" | |||
| TEXT_REPLACE = "text_replace" | |||
| AGENT_LOG = "agent_log" | |||
| class StreamResponse(BaseModel): | |||
| @@ -247,6 +249,7 @@ class NodeStartStreamResponse(StreamResponse): | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| parallel_run_id: Optional[str] = None | |||
| agent_strategy: Optional[AgentNodeStrategyInit] = None | |||
| event: StreamEvent = StreamEvent.NODE_STARTED | |||
| workflow_run_id: str | |||
| @@ -696,3 +699,26 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): | |||
| workflow_run_id: str | |||
| data: Data | |||
| class AgentLogStreamResponse(StreamResponse): | |||
| """ | |||
| AgentLogStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| node_execution_id: str | |||
| id: str | |||
| label: str | |||
| parent_id: str | None | |||
| error: str | None | |||
| status: str | |||
| data: Mapping[str, Any] | |||
| metadata: Optional[Mapping[str, Any]] = None | |||
| event: StreamEvent = StreamEvent.AGENT_LOG | |||
| data: Data | |||
| @@ -24,6 +24,8 @@ class HostingModerationFeature: | |||
| if isinstance(prompt_message.content, str): | |||
| text += prompt_message.content + "\n" | |||
| moderation_result = moderation.check_moderation(model_config, text) | |||
| moderation_result = moderation.check_moderation( | |||
| tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text | |||
| ) | |||
| return moderation_result | |||
| @@ -215,7 +215,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| and text_to_speech_dict.get("autoPlay") == "enabled" | |||
| and text_to_speech_dict.get("enabled") | |||
| ): | |||
| publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) | |||
| publisher = AppGeneratorTTSPublisher( | |||
| tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) | |||
| ) | |||
| for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | |||
| while True: | |||
| audio_response = self._listen_audio_msg(publisher, task_id) | |||
| @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAgentLogEvent, | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| @@ -24,6 +25,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AgentLogStreamResponse, | |||
| IterationNodeCompletedStreamResponse, | |||
| IterationNodeNextStreamResponse, | |||
| IterationNodeStartStreamResponse, | |||
| @@ -320,9 +322,8 @@ class WorkflowCycleManage: | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| execution_metadata = ( | |||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||
| ) | |||
| execution_metadata_dict = dict(event.execution_metadata or {}) | |||
| execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - event.start_at).total_seconds() | |||
| @@ -540,6 +541,7 @@ class WorkflowCycleManage: | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| parallel_run_id=event.parallel_mode_run_id, | |||
| agent_strategy=event.agent_strategy, | |||
| ), | |||
| ) | |||
| @@ -843,3 +845,24 @@ class WorkflowCycleManage: | |||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | |||
| cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] | |||
| return session.merge(cached_workflow_node_execution) | |||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||
| """ | |||
| Handle agent log | |||
| :param task_id: task id | |||
| :param event: agent log event | |||
| :return: | |||
| """ | |||
| return AgentLogStreamResponse( | |||
| task_id=task_id, | |||
| data=AgentLogStreamResponse.Data( | |||
| node_execution_id=event.node_execution_id, | |||
| id=event.id, | |||
| parent_id=event.parent_id, | |||
| label=event.label, | |||
| error=event.error, | |||
| status=event.status, | |||
| data=event.data, | |||
| metadata=event.metadata, | |||
| ), | |||
| ) | |||
| @@ -1,4 +1,4 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from collections.abc import Iterable, Mapping | |||
| from typing import Any, Optional, TextIO, Union | |||
| from pydantic import BaseModel | |||
| @@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Mapping[str, Any], | |||
| tool_outputs: Sequence[ToolInvokeMessage] | str, | |||
| tool_outputs: Iterable[ToolInvokeMessage] | str, | |||
| message_id: Optional[str] = None, | |||
| timer: Optional[Any] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| @@ -1,5 +1,26 @@ | |||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler | |||
| from collections.abc import Generator, Iterable, Mapping | |||
| from typing import Any, Optional | |||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| def on_tool_execution( | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Mapping[str, Any], | |||
| tool_outputs: Iterable[ToolInvokeMessage], | |||
| message_id: Optional[str] = None, | |||
| timer: Optional[Any] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||
| for tool_output in tool_outputs: | |||
| print_text("\n[on_tool_execution]\n", color=self.color) | |||
| print_text("Tool: " + tool_name + "\n", color=self.color) | |||
| print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color) | |||
| print_text("\n") | |||
| yield tool_output | |||
| @@ -0,0 +1 @@ | |||
| DEFAULT_PLUGIN_ID = "langgenius" | |||
| @@ -0,0 +1,42 @@ | |||
| from enum import StrEnum | |||
| class CommonParameterType(StrEnum): | |||
| SECRET_INPUT = "secret-input" | |||
| TEXT_INPUT = "text-input" | |||
| SELECT = "select" | |||
| STRING = "string" | |||
| NUMBER = "number" | |||
| FILE = "file" | |||
| FILES = "files" | |||
| SYSTEM_FILES = "system-files" | |||
| BOOLEAN = "boolean" | |||
| APP_SELECTOR = "app-selector" | |||
| MODEL_SELECTOR = "model-selector" | |||
| TOOLS_SELECTOR = "array[tools]" | |||
| # TOOL_SELECTOR = "tool-selector" | |||
| class AppSelectorScope(StrEnum): | |||
| ALL = "all" | |||
| CHAT = "chat" | |||
| WORKFLOW = "workflow" | |||
| COMPLETION = "completion" | |||
| class ModelSelectorScope(StrEnum): | |||
| LLM = "llm" | |||
| TEXT_EMBEDDING = "text-embedding" | |||
| RERANK = "rerank" | |||
| TTS = "tts" | |||
| SPEECH2TEXT = "speech2text" | |||
| MODERATION = "moderation" | |||
| VISION = "vision" | |||
| class ToolSelectorScope(StrEnum): | |||
| ALL = "all" | |||
| CUSTOM = "custom" | |||
| BUILTIN = "builtin" | |||
| WORKFLOW = "workflow" | |||
| @@ -2,13 +2,14 @@ import datetime | |||
| import json | |||
| import logging | |||
| from collections import defaultdict | |||
| from collections.abc import Iterator | |||
| from collections.abc import Iterator, Sequence | |||
| from json import JSONDecodeError | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from constants import HIDDEN_VALUE | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | |||
| from core.entities.provider_entities import ( | |||
| CustomConfiguration, | |||
| @@ -18,16 +19,15 @@ from core.entities.provider_entities import ( | |||
| ) | |||
| from core.helper import encrypter | |||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | |||
| from core.model_runtime.entities.model_entities import FetchFrom, ModelType | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType | |||
| from core.model_runtime.entities.provider_entities import ( | |||
| ConfigurateMethod, | |||
| CredentialFormSchema, | |||
| FormType, | |||
| ProviderEntity, | |||
| ) | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from extensions.ext_database import db | |||
| from models.provider import ( | |||
| LoadBalancingModelConfig, | |||
| @@ -99,9 +99,10 @@ class ProviderConfiguration(BaseModel): | |||
| continue | |||
| restrict_models = quota_configuration.restrict_models | |||
| if self.system_configuration.credentials is None: | |||
| return None | |||
| copy_credentials = self.system_configuration.credentials.copy() | |||
| copy_credentials = ( | |||
| self.system_configuration.credentials.copy() if self.system_configuration.credentials else {} | |||
| ) | |||
| if restrict_models: | |||
| for restrict_model in restrict_models: | |||
| if ( | |||
| @@ -140,6 +141,9 @@ class ProviderConfiguration(BaseModel): | |||
| if current_quota_configuration is None: | |||
| return None | |||
| if not current_quota_configuration: | |||
| return SystemConfigurationStatus.UNSUPPORTED | |||
| return ( | |||
| SystemConfigurationStatus.ACTIVE | |||
| if current_quota_configuration.is_valid | |||
| @@ -153,7 +157,7 @@ class ProviderConfiguration(BaseModel): | |||
| """ | |||
| return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 | |||
| def get_custom_credentials(self, obfuscated: bool = False): | |||
| def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: | |||
| """ | |||
| Get custom credentials. | |||
| @@ -175,7 +179,7 @@ class ProviderConfiguration(BaseModel): | |||
| else [], | |||
| ) | |||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: | |||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| @@ -219,6 +223,7 @@ class ProviderConfiguration(BaseModel): | |||
| if value == HIDDEN_VALUE and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| credentials = model_provider_factory.provider_credentials_validate( | |||
| provider=self.provider.provider, credentials=credentials | |||
| ) | |||
| @@ -246,13 +251,13 @@ class ProviderConfiguration(BaseModel): | |||
| provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| provider_record = Provider( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps(credentials), | |||
| is_valid=True, | |||
| ) | |||
| provider_record = Provider() | |||
| provider_record.tenant_id = self.tenant_id | |||
| provider_record.provider_name = self.provider.provider | |||
| provider_record.provider_type = ProviderType.CUSTOM.value | |||
| provider_record.encrypted_config = json.dumps(credentials) | |||
| provider_record.is_valid = True | |||
| db.session.add(provider_record) | |||
| db.session.commit() | |||
| @@ -327,7 +332,7 @@ class ProviderConfiguration(BaseModel): | |||
| def custom_model_credentials_validate( | |||
| self, model_type: ModelType, model: str, credentials: dict | |||
| ) -> tuple[Optional[ProviderModel], dict]: | |||
| ) -> tuple[ProviderModel | None, dict]: | |||
| """ | |||
| Validate custom model credentials. | |||
| @@ -370,6 +375,7 @@ class ProviderConfiguration(BaseModel): | |||
| if value == HIDDEN_VALUE and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| credentials = model_provider_factory.model_credentials_validate( | |||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | |||
| ) | |||
| @@ -400,14 +406,13 @@ class ProviderConfiguration(BaseModel): | |||
| provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| provider_model_record = ProviderModel( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_name=model, | |||
| model_type=model_type.to_origin_model_type(), | |||
| encrypted_config=json.dumps(credentials), | |||
| is_valid=True, | |||
| ) | |||
| provider_model_record = ProviderModel() | |||
| provider_model_record.tenant_id = self.tenant_id | |||
| provider_model_record.provider_name = self.provider.provider | |||
| provider_model_record.model_name = model | |||
| provider_model_record.model_type = model_type.to_origin_model_type() | |||
| provider_model_record.encrypted_config = json.dumps(credentials) | |||
| provider_model_record.is_valid = True | |||
| db.session.add(provider_model_record) | |||
| db.session.commit() | |||
| @@ -474,13 +479,12 @@ class ProviderConfiguration(BaseModel): | |||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| model_setting = ProviderModelSetting( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| enabled=True, | |||
| ) | |||
| model_setting = ProviderModelSetting() | |||
| model_setting.tenant_id = self.tenant_id | |||
| model_setting.provider_name = self.provider.provider | |||
| model_setting.model_type = model_type.to_origin_model_type() | |||
| model_setting.model_name = model | |||
| model_setting.enabled = True | |||
| db.session.add(model_setting) | |||
| db.session.commit() | |||
| @@ -509,13 +513,12 @@ class ProviderConfiguration(BaseModel): | |||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| model_setting = ProviderModelSetting( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| enabled=False, | |||
| ) | |||
| model_setting = ProviderModelSetting() | |||
| model_setting.tenant_id = self.tenant_id | |||
| model_setting.provider_name = self.provider.provider | |||
| model_setting.model_type = model_type.to_origin_model_type() | |||
| model_setting.model_name = model | |||
| model_setting.enabled = False | |||
| db.session.add(model_setting) | |||
| db.session.commit() | |||
| @@ -576,13 +579,12 @@ class ProviderConfiguration(BaseModel): | |||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| model_setting = ProviderModelSetting( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| load_balancing_enabled=True, | |||
| ) | |||
| model_setting = ProviderModelSetting() | |||
| model_setting.tenant_id = self.tenant_id | |||
| model_setting.provider_name = self.provider.provider | |||
| model_setting.model_type = model_type.to_origin_model_type() | |||
| model_setting.model_name = model | |||
| model_setting.load_balancing_enabled = True | |||
| db.session.add(model_setting) | |||
| db.session.commit() | |||
| @@ -611,25 +613,17 @@ class ProviderConfiguration(BaseModel): | |||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| else: | |||
| model_setting = ProviderModelSetting( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| load_balancing_enabled=False, | |||
| ) | |||
| model_setting = ProviderModelSetting() | |||
| model_setting.tenant_id = self.tenant_id | |||
| model_setting.provider_name = self.provider.provider | |||
| model_setting.model_type = model_type.to_origin_model_type() | |||
| model_setting.model_name = model | |||
| model_setting.load_balancing_enabled = False | |||
| db.session.add(model_setting) | |||
| db.session.commit() | |||
| return model_setting | |||
| def get_provider_instance(self) -> ModelProvider: | |||
| """ | |||
| Get provider instance. | |||
| :return: | |||
| """ | |||
| return model_provider_factory.get_provider_instance(self.provider.provider) | |||
| def get_model_type_instance(self, model_type: ModelType) -> AIModel: | |||
| """ | |||
| Get current model type instance. | |||
| @@ -637,11 +631,19 @@ class ProviderConfiguration(BaseModel): | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| # Get provider instance | |||
| provider_instance = self.get_provider_instance() | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| # Get model instance of LLM | |||
| return provider_instance.get_model_instance(model_type) | |||
| return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) | |||
| def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None: | |||
| """ | |||
| Get model schema | |||
| """ | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| return model_provider_factory.get_model_schema( | |||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | |||
| ) | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: | |||
| """ | |||
| @@ -668,11 +670,10 @@ class ProviderConfiguration(BaseModel): | |||
| if preferred_model_provider: | |||
| preferred_model_provider.preferred_provider_type = provider_type.value | |||
| else: | |||
| preferred_model_provider = TenantPreferredModelProvider( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| preferred_provider_type=provider_type.value, | |||
| ) | |||
| preferred_model_provider = TenantPreferredModelProvider() | |||
| preferred_model_provider.tenant_id = self.tenant_id | |||
| preferred_model_provider.provider_name = self.provider.provider | |||
| preferred_model_provider.preferred_provider_type = provider_type.value | |||
| db.session.add(preferred_model_provider) | |||
| db.session.commit() | |||
| @@ -737,13 +738,14 @@ class ProviderConfiguration(BaseModel): | |||
| :param only_active: only active models | |||
| :return: | |||
| """ | |||
| provider_instance = self.get_provider_instance() | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) | |||
| model_types = [] | |||
| model_types: list[ModelType] = [] | |||
| if model_type: | |||
| model_types.append(model_type) | |||
| else: | |||
| model_types = list(provider_instance.get_provider_schema().supported_model_types) | |||
| model_types = list(provider_schema.supported_model_types) | |||
| # Group model settings by model type and model | |||
| model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) | |||
| @@ -752,11 +754,11 @@ class ProviderConfiguration(BaseModel): | |||
| if self.using_provider_type == ProviderType.SYSTEM: | |||
| provider_models = self._get_system_provider_models( | |||
| model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | |||
| model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map | |||
| ) | |||
| else: | |||
| provider_models = self._get_custom_provider_models( | |||
| model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | |||
| model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map | |||
| ) | |||
| if only_active: | |||
| @@ -767,23 +769,26 @@ class ProviderConfiguration(BaseModel): | |||
| def _get_system_provider_models( | |||
| self, | |||
| model_types: list[ModelType], | |||
| provider_instance: ModelProvider, | |||
| model_types: Sequence[ModelType], | |||
| provider_schema: ProviderEntity, | |||
| model_setting_map: dict[ModelType, dict[str, ModelSettings]], | |||
| ) -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get system provider models. | |||
| :param model_types: model types | |||
| :param provider_instance: provider instance | |||
| :param provider_schema: provider schema | |||
| :param model_setting_map: model setting map | |||
| :return: | |||
| """ | |||
| provider_models = [] | |||
| for model_type in model_types: | |||
| for m in provider_instance.models(model_type): | |||
| for m in provider_schema.models: | |||
| if m.model_type != model_type: | |||
| continue | |||
| status = ModelStatus.ACTIVE | |||
| if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | |||
| if m.model in model_setting_map: | |||
| model_setting = model_setting_map[m.model_type][m.model] | |||
| if model_setting.enabled is False: | |||
| status = ModelStatus.DISABLED | |||
| @@ -804,7 +809,7 @@ class ProviderConfiguration(BaseModel): | |||
| if self.provider.provider not in original_provider_configurate_methods: | |||
| original_provider_configurate_methods[self.provider.provider] = [] | |||
| for configurate_method in provider_instance.get_provider_schema().configurate_methods: | |||
| for configurate_method in provider_schema.configurate_methods: | |||
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | |||
| should_use_custom_model = False | |||
| @@ -825,18 +830,22 @@ class ProviderConfiguration(BaseModel): | |||
| ]: | |||
| # only customizable model | |||
| for restrict_model in restrict_models: | |||
| if self.system_configuration.credentials is not None: | |||
| copy_credentials = self.system_configuration.credentials.copy() | |||
| if restrict_model.base_model_name: | |||
| copy_credentials["base_model_name"] = restrict_model.base_model_name | |||
| try: | |||
| custom_model_schema = provider_instance.get_model_instance( | |||
| restrict_model.model_type | |||
| ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) | |||
| except Exception as ex: | |||
| logger.warning(f"get custom model schema failed, {ex}") | |||
| continue | |||
| copy_credentials = ( | |||
| self.system_configuration.credentials.copy() | |||
| if self.system_configuration.credentials | |||
| else {} | |||
| ) | |||
| if restrict_model.base_model_name: | |||
| copy_credentials["base_model_name"] = restrict_model.base_model_name | |||
| try: | |||
| custom_model_schema = self.get_model_schema( | |||
| model_type=restrict_model.model_type, | |||
| model=restrict_model.model, | |||
| credentials=copy_credentials, | |||
| ) | |||
| except Exception as ex: | |||
| logger.warning(f"get custom model schema failed, {ex}") | |||
| if not custom_model_schema: | |||
| continue | |||
| @@ -881,15 +890,15 @@ class ProviderConfiguration(BaseModel): | |||
| def _get_custom_provider_models( | |||
| self, | |||
| model_types: list[ModelType], | |||
| provider_instance: ModelProvider, | |||
| model_types: Sequence[ModelType], | |||
| provider_schema: ProviderEntity, | |||
| model_setting_map: dict[ModelType, dict[str, ModelSettings]], | |||
| ) -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get custom provider models. | |||
| :param model_types: model types | |||
| :param provider_instance: provider instance | |||
| :param provider_schema: provider schema | |||
| :param model_setting_map: model setting map | |||
| :return: | |||
| """ | |||
| @@ -903,8 +912,10 @@ class ProviderConfiguration(BaseModel): | |||
| if model_type not in self.provider.supported_model_types: | |||
| continue | |||
| models = provider_instance.models(model_type) | |||
| for m in models: | |||
| for m in provider_schema.models: | |||
| if m.model_type != model_type: | |||
| continue | |||
| status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE | |||
| load_balancing_enabled = False | |||
| if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | |||
| @@ -936,10 +947,10 @@ class ProviderConfiguration(BaseModel): | |||
| continue | |||
| try: | |||
| custom_model_schema = provider_instance.get_model_instance( | |||
| model_configuration.model_type | |||
| ).get_customizable_model_schema_from_credentials( | |||
| model_configuration.model, model_configuration.credentials | |||
| custom_model_schema = self.get_model_schema( | |||
| model_type=model_configuration.model_type, | |||
| model=model_configuration.model, | |||
| credentials=model_configuration.credentials, | |||
| ) | |||
| except Exception as ex: | |||
| logger.warning(f"get custom model schema failed, {ex}") | |||
| @@ -967,7 +978,7 @@ class ProviderConfiguration(BaseModel): | |||
| label=custom_model_schema.label, | |||
| model_type=custom_model_schema.model_type, | |||
| features=custom_model_schema.features, | |||
| fetch_from=custom_model_schema.fetch_from, | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties=custom_model_schema.model_properties, | |||
| deprecated=custom_model_schema.deprecated, | |||
| provider=SimpleModelProviderEntity(self.provider), | |||
| @@ -1040,6 +1051,9 @@ class ProviderConfigurations(BaseModel): | |||
| return list(self.values()) | |||
| def __getitem__(self, key): | |||
| if "/" not in key: | |||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||
| return self.configurations[key] | |||
| def __setitem__(self, key, value): | |||
| @@ -1051,8 +1065,11 @@ class ProviderConfigurations(BaseModel): | |||
| def values(self) -> Iterator[ProviderConfiguration]: | |||
| return iter(self.configurations.values()) | |||
| def get(self, key, default=None): | |||
| return self.configurations.get(key, default) | |||
| def get(self, key, default=None) -> ProviderConfiguration | None: | |||
| if "/" not in key: | |||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||
| return self.configurations.get(key, default) # type: ignore | |||
| class ProviderModelBundle(BaseModel): | |||
| @@ -1061,7 +1078,6 @@ class ProviderModelBundle(BaseModel): | |||
| """ | |||
| configuration: ProviderConfiguration | |||
| provider_instance: ModelProvider | |||
| model_type_instance: AIModel | |||
| # pydantic configs | |||
| @@ -1,10 +1,34 @@ | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from typing import Optional, Union | |||
| from pydantic import BaseModel, ConfigDict | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from core.entities.parameter_entities import ( | |||
| AppSelectorScope, | |||
| CommonParameterType, | |||
| ModelSelectorScope, | |||
| ToolSelectorScope, | |||
| ) | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.provider import ProviderQuotaType | |||
| from core.tools.entities.common_entities import I18nObject | |||
| class ProviderQuotaType(Enum): | |||
| PAID = "paid" | |||
| """hosted paid quota""" | |||
| FREE = "free" | |||
| """third-party free quota""" | |||
| TRIAL = "trial" | |||
| """hosted trial quota""" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in ProviderQuotaType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class QuotaUnit(Enum): | |||
| @@ -108,3 +132,55 @@ class ModelSettings(BaseModel): | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| class BasicProviderConfig(BaseModel): | |||
| """ | |||
| Base model class for common provider settings like credentials | |||
| """ | |||
| class Type(Enum): | |||
| SECRET_INPUT = CommonParameterType.SECRET_INPUT.value | |||
| TEXT_INPUT = CommonParameterType.TEXT_INPUT.value | |||
| SELECT = CommonParameterType.SELECT.value | |||
| BOOLEAN = CommonParameterType.BOOLEAN.value | |||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "ProviderConfig.Type": | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f"invalid mode value {value}") | |||
| type: Type = Field(..., description="The type of the credentials") | |||
| name: str = Field(..., description="The name of the credentials") | |||
| class ProviderConfig(BasicProviderConfig): | |||
| """ | |||
| Model class for common provider settings like credentials | |||
| """ | |||
| class Option(BaseModel): | |||
| value: str = Field(..., description="The value of the option") | |||
| label: I18nObject = Field(..., description="The label of the option") | |||
| scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None | |||
| required: bool = False | |||
| default: Optional[Union[int, str]] = None | |||
| options: Optional[list[Option]] = None | |||
| label: Optional[I18nObject] = None | |||
| help: Optional[I18nObject] = None | |||
| url: Optional[str] = None | |||
| placeholder: Optional[I18nObject] = None | |||
| def to_basic_provider_config(self) -> BasicProviderConfig: | |||
| return BasicProviderConfig(type=self.type, name=self.name) | |||
| @@ -20,6 +20,41 @@ def get_signed_file_url(upload_file_id: str) -> str: | |||
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | |||
| def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: | |||
| url = f"{dify_config.FILES_URL}/files/upload/for-plugin" | |||
| if user_id is None: | |||
| user_id = "DEFAULT-USER" | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| key = dify_config.SECRET_KEY.encode() | |||
| msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" | |||
| sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" | |||
| def verify_plugin_file_signature( | |||
| *, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str | |||
| ) -> bool: | |||
| if user_id is None: | |||
| user_id = "DEFAULT-USER" | |||
| data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||
| # verify signature | |||
| if sign != recalculated_encoded_sign: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field, model_validator | |||
| @@ -124,6 +124,17 @@ class File(BaseModel): | |||
| tool_file_id=self.related_id, extension=self.extension | |||
| ) | |||
| def to_plugin_parameter(self) -> dict[str, Any]: | |||
| return { | |||
| "dify_model_identity": FILE_MODEL_IDENTITY, | |||
| "mime_type": self.mime_type, | |||
| "filename": self.filename, | |||
| "extension": self.extension, | |||
| "size": self.size, | |||
| "type": self.type, | |||
| "url": self.generate_url(), | |||
| } | |||
| @model_validator(mode="after") | |||
| def validate_after(self): | |||
| match self.transfer_method: | |||
| @@ -0,0 +1,69 @@ | |||
| import base64 | |||
| import logging | |||
| import time | |||
| from typing import Optional | |||
| from configs import dify_config | |||
| from core.helper.url_signer import UrlSigner | |||
| from extensions.ext_storage import storage | |||
| IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] | |||
| IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) | |||
| class UploadFileParser: | |||
| @classmethod | |||
| def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: | |||
| if not upload_file: | |||
| return None | |||
| if upload_file.extension not in IMAGE_EXTENSIONS: | |||
| return None | |||
| if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url: | |||
| return cls.get_signed_temp_image_url(upload_file.id) | |||
| else: | |||
| # get image file base64 | |||
| try: | |||
| data = storage.load(upload_file.key) | |||
| except FileNotFoundError: | |||
| logging.exception(f"File not found: {upload_file.key}") | |||
| return None | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return f"data:{upload_file.mime_type};base64,{encoded_string}" | |||
| @classmethod | |||
| def get_signed_temp_image_url(cls, upload_file_id) -> str: | |||
| """ | |||
| get signed url from upload file | |||
| :param upload_file: UploadFile object | |||
| :return: | |||
| """ | |||
| base_url = dify_config.FILES_URL | |||
| image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" | |||
| return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview") | |||
| @classmethod | |||
| def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| """ | |||
| verify signature | |||
| :param upload_file_id: file id | |||
| :param timestamp: timestamp | |||
| :param nonce: nonce | |||
| :param sign: signature | |||
| :return: | |||
| """ | |||
| result = UrlSigner.verify( | |||
| sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview" | |||
| ) | |||
| # verify signature | |||
| if not result: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| @@ -0,0 +1,17 @@ | |||
| from core.helper import ssrf_proxy | |||
| def download_with_size_limit(url, max_download_size: int, **kwargs): | |||
| response = ssrf_proxy.get(url, follow_redirects=True, **kwargs) | |||
| if response.status_code == 404: | |||
| raise ValueError("file not found") | |||
| total_size = 0 | |||
| chunks = [] | |||
| for chunk in response.iter_bytes(): | |||
| total_size += len(chunk) | |||
| if total_size > max_download_size: | |||
| raise ValueError("Max file size reached") | |||
| chunks.append(chunk) | |||
| content = b"".join(chunks) | |||
| return content | |||
| @@ -0,0 +1,35 @@ | |||
| from collections.abc import Sequence | |||
| import requests | |||
| from yarl import URL | |||
| from configs import dify_config | |||
| from core.helper.download import download_with_size_limit | |||
| from core.plugin.entities.marketplace import MarketplacePluginDeclaration | |||
| def get_plugin_pkg_url(plugin_unique_identifier: str): | |||
| return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query( | |||
| unique_identifier=plugin_unique_identifier | |||
| ) | |||
| def download_plugin_pkg(plugin_unique_identifier: str): | |||
| url = str(get_plugin_pkg_url(plugin_unique_identifier)) | |||
| return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE) | |||
| def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: | |||
| if len(plugin_ids) == 0: | |||
| return [] | |||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch") | |||
| response = requests.post(url, json={"plugin_ids": plugin_ids}) | |||
| response.raise_for_status() | |||
| return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] | |||
| def record_install_plugin_event(plugin_unique_identifier: str): | |||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count") | |||
| response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) | |||
| response.raise_for_status() | |||
| @@ -1,28 +1,35 @@ | |||
| import logging | |||
| import random | |||
| from typing import cast | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeBadRequestError | |||
| from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel | |||
| from core.model_runtime.model_providers.__base.moderation_model import ModerationModel | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from extensions.ext_hosting_provider import hosting_configuration | |||
| from models.provider import ProviderType | |||
| logger = logging.getLogger(__name__) | |||
| def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: | |||
| def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: | |||
| moderation_config = hosting_configuration.moderation_config | |||
| openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai" | |||
| if ( | |||
| moderation_config | |||
| and moderation_config.enabled is True | |||
| and "openai" in hosting_configuration.provider_map | |||
| and hosting_configuration.provider_map["openai"].enabled is True | |||
| and openai_provider_name in hosting_configuration.provider_map | |||
| and hosting_configuration.provider_map[openai_provider_name].enabled is True | |||
| ): | |||
| using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type | |||
| provider_name = model_config.provider | |||
| if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: | |||
| hosting_openai_config = hosting_configuration.provider_map["openai"] | |||
| assert hosting_openai_config is not None | |||
| hosting_openai_config = hosting_configuration.provider_map[openai_provider_name] | |||
| if hosting_openai_config.credentials is None: | |||
| return False | |||
| # 2000 text per chunk | |||
| length = 2000 | |||
| @@ -34,15 +41,20 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) | |||
| text_chunk = random.choice(text_chunks) | |||
| try: | |||
| model_type_instance = OpenAIModerationModel() | |||
| # FIXME, for type hint using assert or raise ValueError is better here? | |||
| model_provider_factory = ModelProviderFactory(tenant_id) | |||
| # Get model instance of LLM | |||
| model_type_instance = model_provider_factory.get_model_type_instance( | |||
| provider=openai_provider_name, model_type=ModelType.MODERATION | |||
| ) | |||
| model_type_instance = cast(ModerationModel, model_type_instance) | |||
| moderation_result = model_type_instance.invoke( | |||
| model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk | |||
| model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk | |||
| ) | |||
| if moderation_result is True: | |||
| return True | |||
| except Exception as ex: | |||
| except Exception: | |||
| logger.exception(f"Fails to check moderation, provider_name: {provider_name}") | |||
| raise InvokeBadRequestError("Rate limit exceeded, please try again later.") | |||
| @@ -36,7 +36,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| ) | |||
| retries = 0 | |||
| stream = kwargs.pop("stream", False) | |||
| while retries <= max_retries: | |||
| try: | |||
| if dify_config.SSRF_PROXY_ALL_URL: | |||
| @@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client | |||
| class ToolProviderCredentialsCacheType(Enum): | |||
| PROVIDER = "tool_provider" | |||
| ENDPOINT = "endpoint" | |||
| class ToolProviderCredentialsCache: | |||
| @@ -0,0 +1,52 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import os | |||
| import time | |||
| from pydantic import BaseModel, Field | |||
| from configs import dify_config | |||
| class SignedUrlParams(BaseModel): | |||
| sign_key: str = Field(..., description="The sign key") | |||
| timestamp: str = Field(..., description="Timestamp") | |||
| nonce: str = Field(..., description="Nonce") | |||
| sign: str = Field(..., description="Signature") | |||
| class UrlSigner: | |||
| @classmethod | |||
| def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: | |||
| signed_url_params = cls.get_signed_url_params(sign_key, prefix) | |||
| return ( | |||
| f"{url}?timestamp={signed_url_params.timestamp}" | |||
| f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" | |||
| ) | |||
| @classmethod | |||
| def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||
| return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) | |||
| @classmethod | |||
| def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: | |||
| recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||
| return sign == recalculated_sign | |||
| @classmethod | |||
| def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: | |||
| if not dify_config.SECRET_KEY: | |||
| raise Exception("SECRET_KEY is not set") | |||
| data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return encoded_sign | |||
| @@ -4,9 +4,9 @@ from flask import Flask | |||
| from pydantic import BaseModel | |||
| from configs import dify_config | |||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.provider import ProviderQuotaType | |||
| class HostingQuota(BaseModel): | |||
| @@ -48,12 +48,12 @@ class HostingConfiguration: | |||
| if dify_config.EDITION != "CLOUD": | |||
| return | |||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||
| self.provider_map["openai"] = self.init_openai() | |||
| self.provider_map["anthropic"] = self.init_anthropic() | |||
| self.provider_map["minimax"] = self.init_minimax() | |||
| self.provider_map["spark"] = self.init_spark() | |||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() | |||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() | |||
| self.moderation_config = self.init_moderation_config() | |||
| @@ -240,7 +240,14 @@ class HostingConfiguration: | |||
| @staticmethod | |||
| def init_moderation_config() -> HostedModerationConfig: | |||
| if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: | |||
| return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) | |||
| providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",") | |||
| hosted_providers = [] | |||
| for provider in providers: | |||
| if "/" not in provider: | |||
| provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}" | |||
| hosted_providers.append(provider) | |||
| return HostedModerationConfig(enabled=True, providers=hosted_providers) | |||
| return HostedModerationConfig(enabled=False) | |||
| @@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import ( | |||
| FixedRecursiveCharacterTextSplitter, | |||
| ) | |||
| from core.rag.splitter.text_splitter import TextSplitter | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| from core.tools.utils.rag_web_reader import get_image_upload_file_ids | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.ext_storage import storage | |||
| @@ -618,10 +618,8 @@ class IndexingRunner: | |||
| tokens = 0 | |||
| if embedding_model_instance: | |||
| tokens += sum( | |||
| embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) | |||
| for document in chunk_documents | |||
| ) | |||
| page_content_list = [document.page_content for document in chunk_documents] | |||
| tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) | |||
| # load index | |||
| index_processor.load(dataset, chunk_documents, with_keywords=False) | |||
| @@ -48,7 +48,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||
| prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||
| ), | |||
| ) | |||
| answer = cast(str, response.message.content) | |||
| @@ -101,7 +101,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| prompt_messages=list(prompt_messages), | |||
| model_parameters={"max_tokens": 256, "temperature": 0}, | |||
| stream=False, | |||
| ), | |||
| @@ -110,7 +110,7 @@ class LLMGenerator: | |||
| questions = output_parser.parse(cast(str, response.message.content)) | |||
| except InvokeError: | |||
| questions = [] | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("Failed to generate suggested questions after answer") | |||
| questions = [] | |||
| @@ -150,7 +150,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| @@ -200,7 +200,7 @@ class LLMGenerator: | |||
| prompt_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| except InvokeError as e: | |||
| @@ -236,7 +236,7 @@ class LLMGenerator: | |||
| parameter_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) | |||
| @@ -248,7 +248,7 @@ class LLMGenerator: | |||
| statement_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=statement_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| rule_config["opening_statement"] = cast(str, statement_content.message.content) | |||
| @@ -301,7 +301,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| @@ -1,6 +1,6 @@ | |||
| import logging | |||
| from collections.abc import Callable, Generator, Iterable, Sequence | |||
| from typing import IO, Any, Optional, Union, cast | |||
| from typing import IO, Any, Literal, Optional, Union, cast, overload | |||
| from configs import dify_config | |||
| from core.entities.embedding_type import EmbeddingInputType | |||
| @@ -98,6 +98,42 @@ class ModelInstance: | |||
| return None | |||
| @overload | |||
| def invoke_llm( | |||
| self, | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: Optional[dict] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stream: Literal[True] = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> Generator: ... | |||
| @overload | |||
| def invoke_llm( | |||
| self, | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: Optional[dict] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stream: Literal[False] = False, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> LLMResult: ... | |||
| @overload | |||
| def invoke_llm( | |||
| self, | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: Optional[dict] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| ) -> Union[LLMResult, Generator]: ... | |||
| def invoke_llm( | |||
| self, | |||
| prompt_messages: Sequence[PromptMessage], | |||
| @@ -192,7 +228,7 @@ class ModelInstance: | |||
| ), | |||
| ) | |||
| def get_text_embedding_num_tokens(self, texts: list[str]) -> int: | |||
| def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: | |||
| """ | |||
| Get number of tokens for text embedding | |||
| @@ -204,7 +240,7 @@ class ModelInstance: | |||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||
| return cast( | |||
| int, | |||
| list[int], | |||
| self._round_robin_invoke( | |||
| function=self.model_type_instance.get_num_tokens, | |||
| model=self.model, | |||
| @@ -397,7 +433,7 @@ class ModelManager: | |||
| return ModelInstance(provider_model_bundle, model) | |||
| def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: | |||
| def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: | |||
| """ | |||
| Return first provider and the first model in the provider | |||
| :param tenant_id: tenant id | |||