Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>tags/1.0.0
| #!/bin/bash | #!/bin/bash | ||||
| cd web && npm install | |||||
| npm add -g pnpm@9.12.2 | |||||
| cd web && pnpm install | |||||
| pipx install poetry | pipx install poetry | ||||
| echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc | echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc | ||||
| echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc | echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc | ||||
| echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc | |||||
| echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc | |||||
| echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc | echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc | ||||
| echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc | echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc | ||||
| - name: Run Unit tests | - name: Run Unit tests | ||||
| run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh | run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh | ||||
| - name: Run ModelRuntime | |||||
| run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh | |||||
| - name: Run dify config tests | - name: Run dify config tests | ||||
| run: poetry run -P api python dev/pytest/pytest_config_tests.py | run: poetry run -P api python dev/pytest/pytest_config_tests.py | ||||
| - name: Run Tool | |||||
| run: poetry run -P api bash dev/pytest/pytest_tools.sh | |||||
| - name: Run mypy | - name: Run mypy | ||||
| run: | | run: | | ||||
| poetry run -C api python -m mypy --install-types --non-interactive . | poetry run -C api python -m mypy --install-types --non-interactive . | 
| pull_request: | pull_request: | ||||
| branches: | branches: | ||||
| - main | - main | ||||
| - plugins/beta | |||||
| paths: | paths: | ||||
| - api/migrations/** | - api/migrations/** | ||||
| - .github/workflows/db-migration-test.yml | - .github/workflows/db-migration-test.yml | 
| with: | with: | ||||
| files: web/** | files: web/** | ||||
| - name: Install pnpm | |||||
| uses: pnpm/action-setup@v4 | |||||
| with: | |||||
| version: 10 | |||||
| run_install: false | |||||
| - name: Setup NodeJS | - name: Setup NodeJS | ||||
| uses: actions/setup-node@v4 | uses: actions/setup-node@v4 | ||||
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | ||||
| with: | with: | ||||
| node-version: 20 | node-version: 20 | ||||
| cache: yarn | |||||
| cache: pnpm | |||||
| cache-dependency-path: ./web/package.json | cache-dependency-path: ./web/package.json | ||||
| - name: Web dependencies | - name: Web dependencies | ||||
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | ||||
| run: yarn install --frozen-lockfile | |||||
| run: pnpm install --frozen-lockfile | |||||
| - name: Web style check | - name: Web style check | ||||
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | 
| with: | with: | ||||
| node-version: ${{ matrix.node-version }} | node-version: ${{ matrix.node-version }} | ||||
| cache: '' | cache: '' | ||||
| cache-dependency-path: 'yarn.lock' | |||||
| cache-dependency-path: 'pnpm-lock.yaml' | |||||
| - name: Install Dependencies | - name: Install Dependencies | ||||
| run: yarn install | |||||
| run: pnpm install | |||||
| - name: Test | - name: Test | ||||
| run: yarn test | |||||
| run: pnpm test | 
| - name: Install dependencies | - name: Install dependencies | ||||
| if: env.FILES_CHANGED == 'true' | if: env.FILES_CHANGED == 'true' | ||||
| run: yarn install --frozen-lockfile | |||||
| run: pnpm install --frozen-lockfile | |||||
| - name: Run npm script | - name: Run npm script | ||||
| if: env.FILES_CHANGED == 'true' | if: env.FILES_CHANGED == 'true' | ||||
| run: npm run auto-gen-i18n | |||||
| run: pnpm run auto-gen-i18n | |||||
| - name: Create Pull Request | - name: Create Pull Request | ||||
| if: env.FILES_CHANGED == 'true' | if: env.FILES_CHANGED == 'true' | 
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | ||||
| with: | with: | ||||
| node-version: 20 | node-version: 20 | ||||
| cache: yarn | |||||
| cache: pnpm | |||||
| cache-dependency-path: ./web/package.json | cache-dependency-path: ./web/package.json | ||||
| - name: Install dependencies | - name: Install dependencies | ||||
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | ||||
| run: yarn install --frozen-lockfile | |||||
| run: pnpm install --frozen-lockfile | |||||
| - name: Run tests | - name: Run tests | ||||
| if: steps.changed-files.outputs.any_changed == 'true' | if: steps.changed-files.outputs.any_changed == 'true' | ||||
| run: yarn test | |||||
| run: pnpm test | 
| docker/volumes/pgvecto_rs/data/* | docker/volumes/pgvecto_rs/data/* | ||||
| docker/volumes/couchbase/* | docker/volumes/couchbase/* | ||||
| docker/volumes/oceanbase/* | docker/volumes/oceanbase/* | ||||
| docker/volumes/plugin_daemon/* | |||||
| !docker/volumes/oceanbase/init.d | !docker/volumes/oceanbase/init.d | ||||
| docker/nginx/conf.d/default.conf | docker/nginx/conf.d/default.conf | ||||
| .idea/ | .idea/ | ||||
| .vscode | .vscode | ||||
| # pnpm | |||||
| /.pnpm-store | |||||
| # plugin migrate | |||||
| plugins.jsonl | 
| .env | .env | ||||
| *.env.* | *.env.* | ||||
| storage/generate_files/* | |||||
| storage/privkeys/* | storage/privkeys/* | ||||
| storage/tools/* | |||||
| storage/upload_files/* | |||||
| # Logs | # Logs | ||||
| logs | logs | ||||
| # jetbrains | # jetbrains | ||||
| .idea | .idea | ||||
| .mypy_cache | |||||
| .ruff_cache | |||||
| # venv | # venv | ||||
| .venv | .venv | 
| APP_MAX_EXECUTION_TIME=1200 | APP_MAX_EXECUTION_TIME=1200 | ||||
| APP_MAX_ACTIVE_REQUESTS=0 | APP_MAX_ACTIVE_REQUESTS=0 | ||||
| # Celery beat configuration | # Celery beat configuration | ||||
| CELERY_BEAT_SCHEDULER_TIME=1 | CELERY_BEAT_SCHEDULER_TIME=1 | ||||
| POSITION_PROVIDER_INCLUDES= | POSITION_PROVIDER_INCLUDES= | ||||
| POSITION_PROVIDER_EXCLUDES= | POSITION_PROVIDER_EXCLUDES= | ||||
| # Plugin configuration | |||||
| PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi | |||||
| PLUGIN_DAEMON_URL=http://127.0.0.1:5002 | |||||
| PLUGIN_REMOTE_INSTALL_PORT=5003 | |||||
| PLUGIN_REMOTE_INSTALL_HOST=localhost | |||||
| PLUGIN_MAX_PACKAGE_SIZE=15728640 | |||||
| INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 | |||||
| INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 | |||||
| # Marketplace configuration | |||||
| MARKETPLACE_ENABLED=true | |||||
| MARKETPLACE_API_URL=https://marketplace.dify.ai | |||||
| # Endpoint configuration | |||||
| ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} | |||||
| # Reset password token expiry minutes | # Reset password token expiry minutes | ||||
| RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 | RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 | ||||
| # Download nltk data | # Download nltk data | ||||
| RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" | RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" | ||||
| ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache | |||||
| RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" | |||||
| # Copy source code | # Copy source code | ||||
| COPY . /app/api/ | COPY . /app/api/ | ||||
| from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | ||||
| from models.provider import Provider, ProviderModel | from models.provider import Provider, ProviderModel | ||||
| from services.account_service import RegisterService, TenantService | from services.account_service import RegisterService, TenantService | ||||
| from services.plugin.data_migration import PluginDataMigration | |||||
| from services.plugin.plugin_migration import PluginMigration | |||||
| @click.command("reset-password", help="Reset the account password.") | @click.command("reset-password", help="Reset the account password.") | ||||
| ) | ) | ||||
| ) | ) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| click.echo(click.style("Failed to create Qdrant client.", fg="red")) | click.echo(click.style("Failed to create Qdrant client.", fg="red")) | ||||
| click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) | click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) | ||||
| click.echo(click.style("Database migration successful!", fg="green")) | click.echo(click.style("Database migration successful!", fg="green")) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("Failed to execute database migration") | logging.exception("Failed to execute database migration") | ||||
| finally: | finally: | ||||
| lock.release() | lock.release() | ||||
| account = accounts[0] | account = accounts[0] | ||||
| print("Fixing missing site for app {}".format(app.id)) | print("Fixing missing site for app {}".format(app.id)) | ||||
| app_was_created.send(app, account=account) | app_was_created.send(app, account=account) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| failed_app_ids.append(app_id) | failed_app_ids.append(app_id) | ||||
| click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) | click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) | ||||
| logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") | logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") | ||||
| break | break | ||||
| click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) | click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) | ||||
| @click.command("migrate-data-for-plugin", help="Migrate data for plugin.") | |||||
| def migrate_data_for_plugin(): | |||||
| """ | |||||
| Migrate data for plugin. | |||||
| """ | |||||
| click.echo(click.style("Starting migrate data for plugin.", fg="white")) | |||||
| PluginDataMigration.migrate() | |||||
| click.echo(click.style("Migrate data for plugin completed.", fg="green")) | |||||
| @click.command("extract-plugins", help="Extract plugins.") | |||||
| @click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") | |||||
| @click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) | |||||
| def extract_plugins(output_file: str, workers: int): | |||||
| """ | |||||
| Extract plugins. | |||||
| """ | |||||
| click.echo(click.style("Starting extract plugins.", fg="white")) | |||||
| PluginMigration.extract_plugins(output_file, workers) | |||||
| click.echo(click.style("Extract plugins completed.", fg="green")) | |||||
| @click.command("extract-unique-identifiers", help="Extract unique identifiers.") | |||||
| @click.option( | |||||
| "--output_file", | |||||
| prompt=True, | |||||
| help="The file to store the extracted unique identifiers.", | |||||
| default="unique_identifiers.json", | |||||
| ) | |||||
| @click.option( | |||||
| "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" | |||||
| ) | |||||
| def extract_unique_plugins(output_file: str, input_file: str): | |||||
| """ | |||||
| Extract unique plugins. | |||||
| """ | |||||
| click.echo(click.style("Starting extract unique plugins.", fg="white")) | |||||
| PluginMigration.extract_unique_plugins_to_file(input_file, output_file) | |||||
| click.echo(click.style("Extract unique plugins completed.", fg="green")) | |||||
| @click.command("install-plugins", help="Install plugins.") | |||||
| @click.option( | |||||
| "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" | |||||
| ) | |||||
| @click.option( | |||||
| "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" | |||||
| ) | |||||
| def install_plugins(input_file: str, output_file: str): | |||||
| """ | |||||
| Install plugins. | |||||
| """ | |||||
| click.echo(click.style("Starting install plugins.", fg="white")) | |||||
| PluginMigration.install_plugins(input_file, output_file) | |||||
| click.echo(click.style("Install plugins completed.", fg="green")) | 
| ) | ) | ||||
| class PluginConfig(BaseSettings): | |||||
| """ | |||||
| Plugin configs | |||||
| """ | |||||
| PLUGIN_DAEMON_URL: HttpUrl = Field( | |||||
| description="Plugin API URL", | |||||
| default="http://localhost:5002", | |||||
| ) | |||||
| PLUGIN_DAEMON_KEY: str = Field( | |||||
| description="Plugin API key", | |||||
| default="plugin-api-key", | |||||
| ) | |||||
| INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") | |||||
| PLUGIN_REMOTE_INSTALL_HOST: str = Field( | |||||
| description="Plugin Remote Install Host", | |||||
| default="localhost", | |||||
| ) | |||||
| PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field( | |||||
| description="Plugin Remote Install Port", | |||||
| default=5003, | |||||
| ) | |||||
| PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field( | |||||
| description="Maximum allowed size for plugin packages in bytes", | |||||
| default=15728640, | |||||
| ) | |||||
| PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field( | |||||
| description="Maximum allowed size for plugin bundles in bytes", | |||||
| default=15728640 * 12, | |||||
| ) | |||||
| class MarketplaceConfig(BaseSettings): | |||||
| """ | |||||
| Configuration for marketplace | |||||
| """ | |||||
| MARKETPLACE_ENABLED: bool = Field( | |||||
| description="Enable or disable marketplace", | |||||
| default=True, | |||||
| ) | |||||
| MARKETPLACE_API_URL: HttpUrl = Field( | |||||
| description="Marketplace API URL", | |||||
| default="https://marketplace.dify.ai", | |||||
| ) | |||||
| class EndpointConfig(BaseSettings): | class EndpointConfig(BaseSettings): | ||||
| """ | """ | ||||
| Configuration for various application endpoints and URLs | Configuration for various application endpoints and URLs | ||||
| default="", | default="", | ||||
| ) | ) | ||||
| ENDPOINT_URL_TEMPLATE: str = Field( | |||||
| description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}" | |||||
| ) | |||||
| class FileAccessConfig(BaseSettings): | class FileAccessConfig(BaseSettings): | ||||
| """ | """ | ||||
| AuthConfig, # Changed from OAuthConfig to AuthConfig | AuthConfig, # Changed from OAuthConfig to AuthConfig | ||||
| BillingConfig, | BillingConfig, | ||||
| CodeExecutionSandboxConfig, | CodeExecutionSandboxConfig, | ||||
| PluginConfig, | |||||
| MarketplaceConfig, | |||||
| DataSetConfig, | DataSetConfig, | ||||
| EndpointConfig, | EndpointConfig, | ||||
| FileAccessConfig, | FileAccessConfig, | 
| CURRENT_VERSION: str = Field( | CURRENT_VERSION: str = Field( | ||||
| description="Dify version", | description="Dify version", | ||||
| default="0.15.3", | |||||
| default="1.0.0", | |||||
| ) | ) | ||||
| COMMIT_SHA: str = Field( | COMMIT_SHA: str = Field( | 
| from contextvars import ContextVar | from contextvars import ContextVar | ||||
| from threading import Lock | |||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||||
| from core.tools.plugin_tool.provider import PluginToolProviderController | |||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| tenant_id: ContextVar[str] = ContextVar("tenant_id") | tenant_id: ContextVar[str] = ContextVar("tenant_id") | ||||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | ||||
| plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") | |||||
| plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") | |||||
| plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") | |||||
| plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") | 
| from libs.external_api import ExternalApi | from libs.external_api import ExternalApi | ||||
| from .app.app_import import AppImportApi, AppImportConfirmApi | |||||
| from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi | |||||
| from .explore.audio import ChatAudioApi, ChatTextApi | from .explore.audio import ChatAudioApi, ChatTextApi | ||||
| from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi | from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi | ||||
| from .explore.conversation import ( | from .explore.conversation import ( | ||||
| # Import App | # Import App | ||||
| api.add_resource(AppImportApi, "/apps/imports") | api.add_resource(AppImportApi, "/apps/imports") | ||||
| api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") | api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm") | ||||
| api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") | |||||
| # Import other controllers | # Import other controllers | ||||
| from . import admin, apikey, extension, feature, ping, setup, version | from . import admin, apikey, extension, feature, ping, setup, version | ||||
| from .tag import tags | from .tag import tags | ||||
| # Import workspace controllers | # Import workspace controllers | ||||
| from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace | |||||
| from .workspace import ( | |||||
| account, | |||||
| agent_providers, | |||||
| endpoint, | |||||
| load_balancing_config, | |||||
| members, | |||||
| model_providers, | |||||
| models, | |||||
| plugin, | |||||
| tool_providers, | |||||
| workspace, | |||||
| ) | 
| from flask import request | from flask import request | ||||
| from flask_restful import Resource, reqparse # type: ignore | from flask_restful import Resource, reqparse # type: ignore | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import NotFound, Unauthorized | from werkzeug.exceptions import NotFound, Unauthorized | ||||
| from configs import dify_config | from configs import dify_config | ||||
| parser.add_argument("position", type=int, required=True, nullable=False, location="json") | parser.add_argument("position", type=int, required=True, nullable=False, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| app = App.query.filter(App.id == args["app_id"]).first() | |||||
| with Session(db.engine) as session: | |||||
| app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() | |||||
| if not app: | if not app: | ||||
| raise NotFound(f"App '{args['app_id']}' is not found") | raise NotFound(f"App '{args['app_id']}' is not found") | ||||
| privacy_policy = site.privacy_policy or args["privacy_policy"] or "" | privacy_policy = site.privacy_policy or args["privacy_policy"] or "" | ||||
| custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" | custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" | ||||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() | |||||
| with Session(db.engine) as session: | |||||
| recommended_app = session.execute( | |||||
| select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) | |||||
| ).scalar_one_or_none() | |||||
| if not recommended_app: | if not recommended_app: | ||||
| recommended_app = RecommendedApp( | recommended_app = RecommendedApp( | ||||
| @only_edition_cloud | @only_edition_cloud | ||||
| @admin_required | @admin_required | ||||
| def delete(self, app_id): | def delete(self, app_id): | ||||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() | |||||
| with Session(db.engine) as session: | |||||
| recommended_app = session.execute( | |||||
| select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) | |||||
| ).scalar_one_or_none() | |||||
| if not recommended_app: | if not recommended_app: | ||||
| return {"result": "success"}, 204 | return {"result": "success"}, 204 | ||||
| app = App.query.filter(App.id == recommended_app.app_id).first() | |||||
| with Session(db.engine) as session: | |||||
| app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() | |||||
| if app: | if app: | ||||
| app.is_public = False | app.is_public = False | ||||
| installed_apps = InstalledApp.query.filter( | |||||
| InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id | |||||
| ).all() | |||||
| with Session(db.engine) as session: | |||||
| installed_apps = session.execute( | |||||
| select(InstalledApp).filter( | |||||
| InstalledApp.app_id == recommended_app.app_id, | |||||
| InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, | |||||
| ) | |||||
| ).all() | |||||
| for installed_app in installed_apps: | for installed_app in installed_apps: | ||||
| db.session.delete(installed_app) | db.session.delete(installed_app) | 
| import flask_restful # type: ignore | import flask_restful # type: ignore | ||||
| from flask_login import current_user # type: ignore | from flask_login import current_user # type: ignore | ||||
| from flask_restful import Resource, fields, marshal_with | from flask_restful import Resource, fields, marshal_with | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import Forbidden | from werkzeug.exceptions import Forbidden | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| def _get_resource(resource_id, tenant_id, resource_model): | def _get_resource(resource_id, tenant_id, resource_model): | ||||
| resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() | |||||
| if resource_model == App: | |||||
| with Session(db.engine) as session: | |||||
| resource = session.execute( | |||||
| select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) | |||||
| ).scalar_one_or_none() | |||||
| else: | |||||
| with Session(db.engine) as session: | |||||
| resource = session.execute( | |||||
| select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) | |||||
| ).scalar_one_or_none() | |||||
| if resource is None: | if resource is None: | ||||
| flask_restful.abort(404, message=f"{resource_model.__name__} not found.") | flask_restful.abort(404, message=f"{resource_model.__name__} not found.") | 
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from werkzeug.exceptions import Forbidden | from werkzeug.exceptions import Forbidden | ||||
| from controllers.console.app.wraps import get_app_model | |||||
| from controllers.console.wraps import ( | from controllers.console.wraps import ( | ||||
| account_initialization_required, | account_initialization_required, | ||||
| setup_required, | setup_required, | ||||
| ) | ) | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from fields.app_fields import app_import_fields | |||||
| from fields.app_fields import app_import_check_dependencies_fields, app_import_fields | |||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from models import Account | from models import Account | ||||
| from models.model import App | |||||
| from services.app_dsl_service import AppDslService, ImportStatus | from services.app_dsl_service import AppDslService, ImportStatus | ||||
| if result.status == ImportStatus.FAILED.value: | if result.status == ImportStatus.FAILED.value: | ||||
| return result.model_dump(mode="json"), 400 | return result.model_dump(mode="json"), 400 | ||||
| return result.model_dump(mode="json"), 200 | return result.model_dump(mode="json"), 200 | ||||
| class AppImportCheckDependenciesApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @get_app_model | |||||
| @account_initialization_required | |||||
| @marshal_with(app_import_check_dependencies_fields) | |||||
| def get(self, app_model: App): | |||||
| if not current_user.is_editor: | |||||
| raise Forbidden() | |||||
| with Session(db.engine) as session: | |||||
| import_service = AppDslService(session) | |||||
| result = import_service.check_dependencies(app_model=app_model) | |||||
| return result.model_dump(mode="json"), 200 | 
| from flask_login import current_user # type: ignore | from flask_login import current_user # type: ignore | ||||
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | from flask_restful import Resource, marshal_with, reqparse # type: ignore | ||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import Forbidden, NotFound | from werkzeug.exceptions import Forbidden, NotFound | ||||
| from constants.languages import supported_language | from constants.languages import supported_language | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| site = Site.query.filter(Site.app_id == app_model.id).one_or_404() | |||||
| for attr_name in [ | |||||
| "title", | |||||
| "icon_type", | |||||
| "icon", | |||||
| "icon_background", | |||||
| "description", | |||||
| "default_language", | |||||
| "chat_color_theme", | |||||
| "chat_color_theme_inverted", | |||||
| "customize_domain", | |||||
| "copyright", | |||||
| "privacy_policy", | |||||
| "custom_disclaimer", | |||||
| "customize_token_strategy", | |||||
| "prompt_public", | |||||
| "show_workflow_steps", | |||||
| "use_icon_as_answer_icon", | |||||
| ]: | |||||
| value = args.get(attr_name) | |||||
| if value is not None: | |||||
| setattr(site, attr_name, value) | |||||
| site.updated_by = current_user.id | |||||
| site.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| db.session.commit() | |||||
| with Session(db.engine) as session: | |||||
| site = session.query(Site).filter(Site.app_id == app_model.id).first() | |||||
| if not site: | |||||
| raise NotFound | |||||
| for attr_name in [ | |||||
| "title", | |||||
| "icon_type", | |||||
| "icon", | |||||
| "icon_background", | |||||
| "description", | |||||
| "default_language", | |||||
| "chat_color_theme", | |||||
| "chat_color_theme_inverted", | |||||
| "customize_domain", | |||||
| "copyright", | |||||
| "privacy_policy", | |||||
| "custom_disclaimer", | |||||
| "customize_token_strategy", | |||||
| "prompt_public", | |||||
| "show_workflow_steps", | |||||
| "use_icon_as_answer_icon", | |||||
| ]: | |||||
| value = args.get(attr_name) | |||||
| if value is not None: | |||||
| setattr(site, attr_name, value) | |||||
| site.updated_by = current_user.id | |||||
| site.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| session.commit() | |||||
| return site | return site | ||||
| from libs.helper import TimestampField, uuid_value | from libs.helper import TimestampField, uuid_value | ||||
| from libs.login import current_user, login_required | from libs.login import current_user, login_required | ||||
| from models import App | from models import App | ||||
| from models.account import Account | |||||
| from models.model import AppMode | from models.model import AppMode | ||||
| from services.app_generate_service import AppGenerateService | from services.app_generate_service import AppGenerateService | ||||
| from services.errors.app import WorkflowHashNotEqualError | from services.errors.app import WorkflowHashNotEqualError | ||||
| else: | else: | ||||
| abort(415) | abort(415) | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| workflow_service = WorkflowService() | workflow_service = WorkflowService() | ||||
| try: | try: | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("inputs", type=dict, location="json") | parser.add_argument("inputs", type=dict, location="json") | ||||
| parser.add_argument("query", type=str, required=True, location="json", default="") | parser.add_argument("query", type=str, required=True, location="json", default="") | ||||
| raise ConversationCompletedError() | raise ConversationCompletedError() | ||||
| except ValueError as e: | except ValueError as e: | ||||
| raise e | raise e | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("internal server error.") | logging.exception("internal server error.") | ||||
| raise InternalServerError() | raise InternalServerError() | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("inputs", type=dict, location="json") | parser.add_argument("inputs", type=dict, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| raise ConversationCompletedError() | raise ConversationCompletedError() | ||||
| except ValueError as e: | except ValueError as e: | ||||
| raise e | raise e | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("internal server error.") | logging.exception("internal server error.") | ||||
| raise InternalServerError() | raise InternalServerError() | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("inputs", type=dict, location="json") | parser.add_argument("inputs", type=dict, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| raise ConversationCompletedError() | raise ConversationCompletedError() | ||||
| except ValueError as e: | except ValueError as e: | ||||
| raise e | raise e | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("internal server error.") | logging.exception("internal server error.") | ||||
| raise InternalServerError() | raise InternalServerError() | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("files", type=list, required=False, location="json") | parser.add_argument("files", type=list, required=False, location="json") | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| inputs = args.get("inputs") | |||||
| if inputs == None: | |||||
| raise ValueError("missing inputs") | |||||
| workflow_service = WorkflowService() | workflow_service = WorkflowService() | ||||
| workflow_node_execution = workflow_service.run_draft_workflow_node( | workflow_node_execution = workflow_service.run_draft_workflow_node( | ||||
| app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user | |||||
| app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user | |||||
| ) | ) | ||||
| return workflow_node_execution | return workflow_node_execution | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| workflow_service = WorkflowService() | workflow_service = WorkflowService() | ||||
| workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("q", type=str, location="args") | parser.add_argument("q", type=str, location="args") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| q = args.get("q") | |||||
| filters = None | filters = None | ||||
| if args.get("q"): | |||||
| if q: | |||||
| try: | try: | ||||
| filters = json.loads(args.get("q", "")) | filters = json.loads(args.get("q", "")) | ||||
| except json.JSONDecodeError: | except json.JSONDecodeError: | ||||
| if not current_user.is_editor: | if not current_user.is_editor: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if not isinstance(current_user, Account): | |||||
| raise Forbidden() | |||||
| if request.data: | if request.data: | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | parser.add_argument("name", type=str, required=False, nullable=True, location="json") | 
| from flask import request | from flask import request | ||||
| from flask_restful import Resource, reqparse # type: ignore | from flask_restful import Resource, reqparse # type: ignore | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from constants.languages import languages | from constants.languages import languages | ||||
| from controllers.console import api | from controllers.console import api | ||||
| else: | else: | ||||
| language = "en-US" | language = "en-US" | ||||
| account = Account.query.filter_by(email=args["email"]).first() | |||||
| with Session(db.engine) as session: | |||||
| account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() | |||||
| token = None | token = None | ||||
| if account is None: | if account is None: | ||||
| if FeatureService.get_system_features().is_allow_register: | if FeatureService.get_system_features().is_allow_register: | ||||
| password_hashed = hash_password(new_password, salt) | password_hashed = hash_password(new_password, salt) | ||||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | base64_password_hashed = base64.b64encode(password_hashed).decode() | ||||
| account = Account.query.filter_by(email=reset_data.get("email")).first() | |||||
| with Session(db.engine) as session: | |||||
| account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none() | |||||
| if account: | if account: | ||||
| account.password = base64_password_hashed | account.password = base64_password_hashed | ||||
| account.password_salt = base64_salt | account.password_salt = base64_salt | ||||
| ) | ) | ||||
| except WorkSpaceNotAllowedCreateError: | except WorkSpaceNotAllowedCreateError: | ||||
| pass | pass | ||||
| except AccountRegisterError as are: | |||||
| except AccountRegisterError: | |||||
| raise AccountInFreezeError() | raise AccountInFreezeError() | ||||
| return {"result": "success"} | return {"result": "success"} | 
| import requests | import requests | ||||
| from flask import current_app, redirect, request | from flask import current_app, redirect, request | ||||
| from flask_restful import Resource # type: ignore | from flask_restful import Resource # type: ignore | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import Unauthorized | from werkzeug.exceptions import Unauthorized | ||||
| from configs import dify_config | from configs import dify_config | ||||
| account: Optional[Account] = Account.get_by_openid(provider, user_info.id) | account: Optional[Account] = Account.get_by_openid(provider, user_info.id) | ||||
| if not account: | if not account: | ||||
| account = Account.query.filter_by(email=user_info.email).first() | |||||
| with Session(db.engine) as session: | |||||
| account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() | |||||
| return account | return account | ||||
| from flask import request | from flask import request | ||||
| from flask_login import current_user # type: ignore | from flask_login import current_user # type: ignore | ||||
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | from flask_restful import Resource, marshal_with, reqparse # type: ignore | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from controllers.console import api | from controllers.console import api | ||||
| def patch(self, binding_id, action): | def patch(self, binding_id, action): | ||||
| binding_id = str(binding_id) | binding_id = str(binding_id) | ||||
| action = str(action) | action = str(action) | ||||
| data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() | |||||
| with Session(db.engine) as session: | |||||
| data_source_binding = session.execute( | |||||
| select(DataSourceOauthBinding).filter_by(id=binding_id) | |||||
| ).scalar_one_or_none() | |||||
| if data_source_binding is None: | if data_source_binding is None: | ||||
| raise NotFound("Data source binding not found.") | raise NotFound("Data source binding not found.") | ||||
| # enable binding | # enable binding | ||||
| def get(self): | def get(self): | ||||
| dataset_id = request.args.get("dataset_id", default=None, type=str) | dataset_id = request.args.get("dataset_id", default=None, type=str) | ||||
| exist_page_ids = [] | exist_page_ids = [] | ||||
| # import notion in the exist dataset | |||||
| if dataset_id: | |||||
| dataset = DatasetService.get_dataset(dataset_id) | |||||
| if not dataset: | |||||
| raise NotFound("Dataset not found.") | |||||
| if dataset.data_source_type != "notion_import": | |||||
| raise ValueError("Dataset is not notion type.") | |||||
| documents = Document.query.filter_by( | |||||
| dataset_id=dataset_id, | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| data_source_type="notion_import", | |||||
| enabled=True, | |||||
| with Session(db.engine) as session: | |||||
| # import notion in the exist dataset | |||||
| if dataset_id: | |||||
| dataset = DatasetService.get_dataset(dataset_id) | |||||
| if not dataset: | |||||
| raise NotFound("Dataset not found.") | |||||
| if dataset.data_source_type != "notion_import": | |||||
| raise ValueError("Dataset is not notion type.") | |||||
| documents = session.execute( | |||||
| select(Document).filter_by( | |||||
| dataset_id=dataset_id, | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| data_source_type="notion_import", | |||||
| enabled=True, | |||||
| ) | |||||
| ).all() | |||||
| if documents: | |||||
| for document in documents: | |||||
| data_source_info = json.loads(document.data_source_info) | |||||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||||
| # get all authorized pages | |||||
| data_source_bindings = session.scalars( | |||||
| select(DataSourceOauthBinding).filter_by( | |||||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||||
| ) | |||||
| ).all() | ).all() | ||||
| if documents: | |||||
| for document in documents: | |||||
| data_source_info = json.loads(document.data_source_info) | |||||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||||
| # get all authorized pages | |||||
| data_source_bindings = DataSourceOauthBinding.query.filter_by( | |||||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||||
| ).all() | |||||
| if not data_source_bindings: | |||||
| return {"notion_info": []}, 200 | |||||
| pre_import_info_list = [] | |||||
| for data_source_binding in data_source_bindings: | |||||
| source_info = data_source_binding.source_info | |||||
| pages = source_info["pages"] | |||||
| # Filter out already bound pages | |||||
| for page in pages: | |||||
| if page["page_id"] in exist_page_ids: | |||||
| page["is_bound"] = True | |||||
| else: | |||||
| page["is_bound"] = False | |||||
| pre_import_info = { | |||||
| "workspace_name": source_info["workspace_name"], | |||||
| "workspace_icon": source_info["workspace_icon"], | |||||
| "workspace_id": source_info["workspace_id"], | |||||
| "pages": pages, | |||||
| } | |||||
| pre_import_info_list.append(pre_import_info) | |||||
| return {"notion_info": pre_import_info_list}, 200 | |||||
| if not data_source_bindings: | |||||
| return {"notion_info": []}, 200 | |||||
| pre_import_info_list = [] | |||||
| for data_source_binding in data_source_bindings: | |||||
| source_info = data_source_binding.source_info | |||||
| pages = source_info["pages"] | |||||
| # Filter out already bound pages | |||||
| for page in pages: | |||||
| if page["page_id"] in exist_page_ids: | |||||
| page["is_bound"] = True | |||||
| else: | |||||
| page["is_bound"] = False | |||||
| pre_import_info = { | |||||
| "workspace_name": source_info["workspace_name"], | |||||
| "workspace_icon": source_info["workspace_icon"], | |||||
| "workspace_id": source_info["workspace_id"], | |||||
| "pages": pages, | |||||
| } | |||||
| pre_import_info_list.append(pre_import_info) | |||||
| return {"notion_info": pre_import_info_list}, 200 | |||||
| class DataSourceNotionApi(Resource): | class DataSourceNotionApi(Resource): | ||||
| def get(self, workspace_id, page_id, page_type): | def get(self, workspace_id, page_id, page_type): | ||||
| workspace_id = str(workspace_id) | workspace_id = str(workspace_id) | ||||
| page_id = str(page_id) | page_id = str(page_id) | ||||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||||
| db.and_( | |||||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||||
| DataSourceOauthBinding.provider == "notion", | |||||
| DataSourceOauthBinding.disabled == False, | |||||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||||
| ) | |||||
| ).first() | |||||
| with Session(db.engine) as session: | |||||
| data_source_binding = session.execute( | |||||
| select(DataSourceOauthBinding).filter( | |||||
| db.and_( | |||||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||||
| DataSourceOauthBinding.provider == "notion", | |||||
| DataSourceOauthBinding.disabled == False, | |||||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||||
| ) | |||||
| ) | |||||
| ).scalar_one_or_none() | |||||
| if not data_source_binding: | if not data_source_binding: | ||||
| raise NotFound("Data source binding not found.") | raise NotFound("Data source binding not found.") | ||||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | ||||
| from core.indexing_runner import IndexingRunner | from core.indexing_runner import IndexingRunner | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.plugin.entities.plugin import ModelProviderID | |||||
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| from core.rag.datasource.vdb.vector_type import VectorType | from core.rag.datasource.vdb.vector_type import VectorType | ||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | from core.rag.extractor.entity.extract_setting import ExtractSetting | ||||
| data = marshal(datasets, dataset_detail_fields) | data = marshal(datasets, dataset_detail_fields) | ||||
| for item in data: | for item in data: | ||||
| # convert embedding_model_provider to plugin standard format | |||||
| if item["indexing_technique"] == "high_quality": | if item["indexing_technique"] == "high_quality": | ||||
| item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) | |||||
| item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | ||||
| if item_model in model_names: | if item_model in model_names: | ||||
| item["embedding_available"] = True | item["embedding_available"] = True | 
| from flask_login import current_user # type: ignore | from flask_login import current_user # type: ignore | ||||
| from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore | from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore | ||||
| from sqlalchemy import asc, desc | from sqlalchemy import asc, desc | ||||
| from transformers.hf_argparser import string_to_bool # type: ignore | |||||
| from werkzeug.exceptions import Forbidden, NotFound | from werkzeug.exceptions import Forbidden, NotFound | ||||
| import services | import services | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | from core.model_runtime.errors.invoke import InvokeAuthorizationError | ||||
| from core.plugin.manager.exc import PluginDaemonClientSideError | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | from core.rag.extractor.entity.extract_setting import ExtractSetting | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| sort = request.args.get("sort", default="-created_at", type=str) | sort = request.args.get("sort", default="-created_at", type=str) | ||||
| # "yes", "true", "t", "y", "1" convert to True, while others convert to False. | # "yes", "true", "t", "y", "1" convert to True, while others convert to False. | ||||
| try: | try: | ||||
| fetch = string_to_bool(request.args.get("fetch", default="false")) | |||||
| except (ArgumentTypeError, ValueError, Exception) as e: | |||||
| fetch_val = request.args.get("fetch", default="false") | |||||
| if isinstance(fetch_val, bool): | |||||
| fetch = fetch_val | |||||
| else: | |||||
| if fetch_val.lower() in ("yes", "true", "t", "y", "1"): | |||||
| fetch = True | |||||
| elif fetch_val.lower() in ("no", "false", "f", "n", "0"): | |||||
| fetch = False | |||||
| else: | |||||
| raise ArgumentTypeError( | |||||
| f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 " | |||||
| f"(case insensitive)." | |||||
| ) | |||||
| except (ArgumentTypeError, ValueError, Exception): | |||||
| fetch = False | fetch = False | ||||
| dataset = DatasetService.get_dataset(dataset_id) | dataset = DatasetService.get_dataset(dataset_id) | ||||
| if not dataset: | if not dataset: | ||||
| ) | ) | ||||
| except ProviderTokenNotInitError as ex: | except ProviderTokenNotInitError as ex: | ||||
| raise ProviderNotInitializeError(ex.description) | raise ProviderNotInitializeError(ex.description) | ||||
| except PluginDaemonClientSideError as ex: | |||||
| raise ProviderNotInitializeError(ex.description) | |||||
| except Exception as e: | except Exception as e: | ||||
| raise IndexingEstimateError(str(e)) | raise IndexingEstimateError(str(e)) | ||||
| ) | ) | ||||
| except ProviderTokenNotInitError as ex: | except ProviderTokenNotInitError as ex: | ||||
| raise ProviderNotInitializeError(ex.description) | raise ProviderNotInitializeError(ex.description) | ||||
| except PluginDaemonClientSideError as ex: | |||||
| raise ProviderNotInitializeError(ex.description) | |||||
| except Exception as e: | except Exception as e: | ||||
| raise IndexingEstimateError(str(e)) | raise IndexingEstimateError(str(e)) | ||||
| from flask import session | from flask import session | ||||
| from flask_restful import Resource, reqparse # type: ignore | from flask_restful import Resource, reqparse # type: ignore | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from extensions.ext_database import db | |||||
| from libs.helper import StrLen | from libs.helper import StrLen | ||||
| from models.model import DifySetup | from models.model import DifySetup | ||||
| from services.account_service import TenantService | from services.account_service import TenantService | ||||
| def get_init_validate_status(): | def get_init_validate_status(): | ||||
| if dify_config.EDITION == "SELF_HOSTED": | if dify_config.EDITION == "SELF_HOSTED": | ||||
| if os.environ.get("INIT_PASSWORD"): | if os.environ.get("INIT_PASSWORD"): | ||||
| return session.get("is_init_validated") or DifySetup.query.first() | |||||
| if session.get("is_init_validated"): | |||||
| return True | |||||
| with Session(db.engine) as db_session: | |||||
| return db_session.execute(select(DifySetup)).scalar_one_or_none() | |||||
| return True | return True | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from libs.helper import StrLen, email, extract_remote_ip | from libs.helper import StrLen, email, extract_remote_ip | ||||
| from libs.password import valid_password | from libs.password import valid_password | ||||
| from models.model import DifySetup | |||||
| from models.model import DifySetup, db | |||||
| from services.account_service import RegisterService, TenantService | from services.account_service import RegisterService, TenantService | ||||
| from . import api | from . import api | ||||
| def get_setup_status(): | def get_setup_status(): | ||||
| if dify_config.EDITION == "SELF_HOSTED": | if dify_config.EDITION == "SELF_HOSTED": | ||||
| return DifySetup.query.first() | |||||
| return True | |||||
| return db.session.query(DifySetup).first() | |||||
| else: | |||||
| return True | |||||
| api.add_resource(SetupApi, "/setup") | api.add_resource(SetupApi, "/setup") | 
| from functools import wraps | |||||
| from flask_login import current_user # type: ignore | |||||
| from sqlalchemy.orm import Session | |||||
| from werkzeug.exceptions import Forbidden | |||||
| from extensions.ext_database import db | |||||
| from models.account import TenantPluginPermission | |||||
| def plugin_permission_required( | |||||
| install_required: bool = False, | |||||
| debug_required: bool = False, | |||||
| ): | |||||
| def interceptor(view): | |||||
| @wraps(view) | |||||
| def decorated(*args, **kwargs): | |||||
| user = current_user | |||||
| tenant_id = user.current_tenant_id | |||||
| with Session(db.engine) as session: | |||||
| permission = ( | |||||
| session.query(TenantPluginPermission) | |||||
| .filter( | |||||
| TenantPluginPermission.tenant_id == tenant_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if not permission: | |||||
| # no permission set, allow access for everyone | |||||
| return view(*args, **kwargs) | |||||
| if install_required: | |||||
| if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: | |||||
| raise Forbidden() | |||||
| if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: | |||||
| pass | |||||
| if debug_required: | |||||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: | |||||
| raise Forbidden() | |||||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: | |||||
| pass | |||||
| return view(*args, **kwargs) | |||||
| return decorated | |||||
| return interceptor | 
| from flask_login import current_user # type: ignore | |||||
| from flask_restful import Resource # type: ignore | |||||
| from controllers.console import api | |||||
| from controllers.console.wraps import account_initialization_required, setup_required | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from libs.login import login_required | |||||
| from services.agent_service import AgentService | |||||
| class AgentProviderListApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) | |||||
| class AgentProviderApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, provider_name: str): | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) | |||||
| api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") | |||||
| api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>") | 
| from flask_login import current_user # type: ignore | |||||
| from flask_restful import Resource, reqparse # type: ignore | |||||
| from werkzeug.exceptions import Forbidden | |||||
| from controllers.console import api | |||||
| from controllers.console.wraps import account_initialization_required, setup_required | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from libs.login import login_required | |||||
| from services.plugin.endpoint_service import EndpointService | |||||
| class EndpointCreateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("plugin_unique_identifier", type=str, required=True) | |||||
| parser.add_argument("settings", type=dict, required=True) | |||||
| parser.add_argument("name", type=str, required=True) | |||||
| args = parser.parse_args() | |||||
| plugin_unique_identifier = args["plugin_unique_identifier"] | |||||
| settings = args["settings"] | |||||
| name = args["name"] | |||||
| return { | |||||
| "success": EndpointService.create_endpoint( | |||||
| tenant_id=user.current_tenant_id, | |||||
| user_id=user.id, | |||||
| plugin_unique_identifier=plugin_unique_identifier, | |||||
| name=name, | |||||
| settings=settings, | |||||
| ) | |||||
| } | |||||
| class EndpointListApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("page", type=int, required=True, location="args") | |||||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||||
| args = parser.parse_args() | |||||
| page = args["page"] | |||||
| page_size = args["page_size"] | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "endpoints": EndpointService.list_endpoints( | |||||
| tenant_id=user.current_tenant_id, | |||||
| user_id=user.id, | |||||
| page=page, | |||||
| page_size=page_size, | |||||
| ) | |||||
| } | |||||
| ) | |||||
| class EndpointListForSinglePluginApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("page", type=int, required=True, location="args") | |||||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||||
| parser.add_argument("plugin_id", type=str, required=True, location="args") | |||||
| args = parser.parse_args() | |||||
| page = args["page"] | |||||
| page_size = args["page_size"] | |||||
| plugin_id = args["plugin_id"] | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "endpoints": EndpointService.list_endpoints_for_single_plugin( | |||||
| tenant_id=user.current_tenant_id, | |||||
| user_id=user.id, | |||||
| plugin_id=plugin_id, | |||||
| page=page, | |||||
| page_size=page_size, | |||||
| ) | |||||
| } | |||||
| ) | |||||
| class EndpointDeleteApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("endpoint_id", type=str, required=True) | |||||
| args = parser.parse_args() | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| endpoint_id = args["endpoint_id"] | |||||
| return { | |||||
| "success": EndpointService.delete_endpoint( | |||||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||||
| ) | |||||
| } | |||||
| class EndpointUpdateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("endpoint_id", type=str, required=True) | |||||
| parser.add_argument("settings", type=dict, required=True) | |||||
| parser.add_argument("name", type=str, required=True) | |||||
| args = parser.parse_args() | |||||
| endpoint_id = args["endpoint_id"] | |||||
| settings = args["settings"] | |||||
| name = args["name"] | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| return { | |||||
| "success": EndpointService.update_endpoint( | |||||
| tenant_id=user.current_tenant_id, | |||||
| user_id=user.id, | |||||
| endpoint_id=endpoint_id, | |||||
| name=name, | |||||
| settings=settings, | |||||
| ) | |||||
| } | |||||
| class EndpointEnableApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("endpoint_id", type=str, required=True) | |||||
| args = parser.parse_args() | |||||
| endpoint_id = args["endpoint_id"] | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| return { | |||||
| "success": EndpointService.enable_endpoint( | |||||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||||
| ) | |||||
| } | |||||
| class EndpointDisableApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("endpoint_id", type=str, required=True) | |||||
| args = parser.parse_args() | |||||
| endpoint_id = args["endpoint_id"] | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| return { | |||||
| "success": EndpointService.disable_endpoint( | |||||
| tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id | |||||
| ) | |||||
| } | |||||
| api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") | |||||
| api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") | |||||
| api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") | |||||
| api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") | |||||
| api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") | |||||
| api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") | |||||
| api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") | 
| # Load Balancing Config | # Load Balancing Config | ||||
| api.add_resource( | api.add_resource( | ||||
| LoadBalancingCredentialsValidateApi, | LoadBalancingCredentialsValidateApi, | ||||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate", | |||||
| "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate", | |||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| LoadBalancingConfigCredentialsValidateApi, | LoadBalancingConfigCredentialsValidateApi, | ||||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", | |||||
| "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", | |||||
| ) | ) | 
| response = {"result": "success" if result else "error"} | response = {"result": "success" if result else "error"} | ||||
| if not result: | if not result: | ||||
| response["error"] = error | |||||
| response["error"] = error or "Unknown error" | |||||
| return response | return response | ||||
| Get model provider icon | Get model provider icon | ||||
| """ | """ | ||||
| def get(self, provider: str, icon_type: str, lang: str): | |||||
| def get(self, tenant_id: str, provider: str, icon_type: str, lang: str): | |||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| icon, mimetype = model_provider_service.get_model_provider_icon( | icon, mimetype = model_provider_service.get_model_provider_icon( | ||||
| tenant_id=tenant_id, | |||||
| provider=provider, | provider=provider, | ||||
| icon_type=icon_type, | icon_type=icon_type, | ||||
| lang=lang, | lang=lang, | ||||
| return data | return data | ||||
| class ModelProviderFreeQuotaSubmitApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider: str): | |||||
| model_provider_service = ModelProviderService() | |||||
| result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) | |||||
| return result | |||||
| class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, provider: str): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("token", type=str, required=False, nullable=True, location="args") | |||||
| args = parser.parse_args() | |||||
| model_provider_service = ModelProviderService() | |||||
| result = model_provider_service.free_quota_qualification_verify( | |||||
| tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] | |||||
| ) | |||||
| return result | |||||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | ||||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials") | |||||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate") | |||||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>") | |||||
| api.add_resource( | |||||
| ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>" | |||||
| ) | |||||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | |||||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | |||||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>") | |||||
| api.add_resource( | api.add_resource( | ||||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type" | |||||
| ) | |||||
| api.add_resource( | |||||
| ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url" | |||||
| ) | |||||
| api.add_resource( | |||||
| ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit" | |||||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" | |||||
| ) | ) | ||||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url") | |||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderFreeQuotaQualificationVerifyApi, | |||||
| "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify", | |||||
| ModelProviderIconApi, | |||||
| "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>", | |||||
| ) | ) | 
| response = {"result": "success" if result else "error"} | response = {"result": "success" if result else "error"} | ||||
| if not result: | if not result: | ||||
| response["error"] = error | |||||
| response["error"] = error or "" | |||||
| return response | return response | ||||
| return jsonable_encoder({"data": models}) | return jsonable_encoder({"data": models}) | ||||
| api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models") | |||||
| api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models") | |||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelEnableApi, | ModelProviderModelEnableApi, | ||||
| "/workspaces/current/model-providers/<string:provider>/models/enable", | |||||
| "/workspaces/current/model-providers/<path:provider>/models/enable", | |||||
| endpoint="model-provider-model-enable", | endpoint="model-provider-model-enable", | ||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelDisableApi, | ModelProviderModelDisableApi, | ||||
| "/workspaces/current/model-providers/<string:provider>/models/disable", | |||||
| "/workspaces/current/model-providers/<path:provider>/models/disable", | |||||
| endpoint="model-provider-model-disable", | endpoint="model-provider-model-disable", | ||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials" | |||||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | |||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate" | |||||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | |||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules" | |||||
| ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules" | |||||
| ) | ) | ||||
| api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") | api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") | ||||
| api.add_resource(DefaultModelApi, "/workspaces/current/default-model") | api.add_resource(DefaultModelApi, "/workspaces/current/default-model") | 
| import io | |||||
| from flask import request, send_file | |||||
| from flask_login import current_user # type: ignore | |||||
| from flask_restful import Resource, reqparse # type: ignore | |||||
| from werkzeug.exceptions import Forbidden | |||||
| from configs import dify_config | |||||
| from controllers.console import api | |||||
| from controllers.console.workspace import plugin_permission_required | |||||
| from controllers.console.wraps import account_initialization_required, setup_required | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.plugin.manager.exc import PluginDaemonClientSideError | |||||
| from libs.login import login_required | |||||
| from models.account import TenantPluginPermission | |||||
| from services.plugin.plugin_permission_service import PluginPermissionService | |||||
| from services.plugin.plugin_service import PluginService | |||||
| class PluginDebuggingKeyApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return { | |||||
| "key": PluginService.get_debugging_key(tenant_id), | |||||
| "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, | |||||
| "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, | |||||
| } | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginListApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| plugins = PluginService.list(tenant_id) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder({"plugins": plugins}) | |||||
| class PluginListInstallationsFromIdsApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("plugin_ids", type=list, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"]) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder({"plugins": plugins}) | |||||
| class PluginIconApi(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| req = reqparse.RequestParser() | |||||
| req.add_argument("tenant_id", type=str, required=True, location="args") | |||||
| req.add_argument("filename", type=str, required=True, location="args") | |||||
| args = req.parse_args() | |||||
| try: | |||||
| icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"]) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE | |||||
| return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) | |||||
| class PluginUploadFromPkgApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| file = request.files["pkg"] | |||||
| # check file size | |||||
| if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE: | |||||
| raise ValueError("File size exceeds the maximum allowed size") | |||||
| content = file.read() | |||||
| try: | |||||
| response = PluginService.upload_pkg(tenant_id, content) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginUploadFromGithubApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("repo", type=str, required=True, location="json") | |||||
| parser.add_argument("version", type=str, required=True, location="json") | |||||
| parser.add_argument("package", type=str, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"]) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginUploadFromBundleApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| file = request.files["bundle"] | |||||
| # check file size | |||||
| if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE: | |||||
| raise ValueError("File size exceeds the maximum allowed size") | |||||
| content = file.read() | |||||
| try: | |||||
| response = PluginService.upload_bundle(tenant_id, content) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginInstallFromPkgApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| # check if all plugin_unique_identifiers are valid string | |||||
| for plugin_unique_identifier in args["plugin_unique_identifiers"]: | |||||
| if not isinstance(plugin_unique_identifier, str): | |||||
| raise ValueError("Invalid plugin unique identifier") | |||||
| try: | |||||
| response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"]) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginInstallFromGithubApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("repo", type=str, required=True, location="json") | |||||
| parser.add_argument("version", type=str, required=True, location="json") | |||||
| parser.add_argument("package", type=str, required=True, location="json") | |||||
| parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| response = PluginService.install_from_github( | |||||
| tenant_id, | |||||
| args["plugin_unique_identifier"], | |||||
| args["repo"], | |||||
| args["version"], | |||||
| args["package"], | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginInstallFromMarketplaceApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(install_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| # check if all plugin_unique_identifiers are valid string | |||||
| for plugin_unique_identifier in args["plugin_unique_identifiers"]: | |||||
| if not isinstance(plugin_unique_identifier, str): | |||||
| raise ValueError("Invalid plugin unique identifier") | |||||
| try: | |||||
| response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"]) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| return jsonable_encoder(response) | |||||
| class PluginFetchManifestApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "manifest": PluginService.fetch_plugin_manifest( | |||||
| tenant_id, args["plugin_unique_identifier"] | |||||
| ).model_dump() | |||||
| } | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginFetchInstallTasksApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("page", type=int, required=True, location="args") | |||||
| parser.add_argument("page_size", type=int, required=True, location="args") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])} | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginFetchInstallTaskApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def get(self, task_id: str): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginDeleteInstallTaskApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self, task_id: str): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return {"success": PluginService.delete_install_task(tenant_id, task_id)} | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginDeleteAllInstallTaskItemsApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return {"success": PluginService.delete_all_install_task_items(tenant_id)} | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginDeleteInstallTaskItemApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self, task_id: str, identifier: str): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginUpgradeFromMarketplaceApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") | |||||
| parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| PluginService.upgrade_plugin_with_marketplace( | |||||
| tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"] | |||||
| ) | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginUpgradeFromGithubApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") | |||||
| parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") | |||||
| parser.add_argument("repo", type=str, required=True, location="json") | |||||
| parser.add_argument("version", type=str, required=True, location="json") | |||||
| parser.add_argument("package", type=str, required=True, location="json") | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| PluginService.upgrade_plugin_with_github( | |||||
| tenant_id, | |||||
| args["original_plugin_unique_identifier"], | |||||
| args["new_plugin_unique_identifier"], | |||||
| args["repo"], | |||||
| args["version"], | |||||
| args["package"], | |||||
| ) | |||||
| ) | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginUninstallApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @plugin_permission_required(debug_required=True) | |||||
| def post(self): | |||||
| req = reqparse.RequestParser() | |||||
| req.add_argument("plugin_installation_id", type=str, required=True, location="json") | |||||
| args = req.parse_args() | |||||
| tenant_id = current_user.current_tenant_id | |||||
| try: | |||||
| return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} | |||||
| except PluginDaemonClientSideError as e: | |||||
| raise ValueError(e) | |||||
| class PluginChangePermissionApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| req = reqparse.RequestParser() | |||||
| req.add_argument("install_permission", type=str, required=True, location="json") | |||||
| req.add_argument("debug_permission", type=str, required=True, location="json") | |||||
| args = req.parse_args() | |||||
| install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) | |||||
| debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) | |||||
| tenant_id = user.current_tenant_id | |||||
| return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} | |||||
| class PluginFetchPermissionApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| permission = PluginPermissionService.get_permission(tenant_id) | |||||
| if not permission: | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, | |||||
| "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, | |||||
| } | |||||
| ) | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "install_permission": permission.install_permission, | |||||
| "debug_permission": permission.debug_permission, | |||||
| } | |||||
| ) | |||||
| api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") | |||||
| api.add_resource(PluginListApi, "/workspaces/current/plugin/list") | |||||
| api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") | |||||
| api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") | |||||
| api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") | |||||
| api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github") | |||||
| api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle") | |||||
| api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg") | |||||
| api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github") | |||||
| api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace") | |||||
| api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github") | |||||
| api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace") | |||||
| api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest") | |||||
| api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks") | |||||
| api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>") | |||||
| api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete") | |||||
| api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") | |||||
| api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") | |||||
| api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") | |||||
| api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") | |||||
| api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") | 
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| req = reqparse.RequestParser() | req = reqparse.RequestParser() | ||||
| req.add_argument( | req.add_argument( | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, provider): | def get(self, provider): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder( | return jsonable_encoder( | ||||
| BuiltinToolManageService.list_builtin_tool_provider_tools( | BuiltinToolManageService.list_builtin_tool_provider_tools( | ||||
| user_id, | |||||
| tenant_id, | tenant_id, | ||||
| provider, | provider, | ||||
| ) | ) | ||||
| ) | ) | ||||
| class ToolBuiltinProviderInfoApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, provider): | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) | |||||
| class ToolBuiltinProviderDeleteApi(Resource): | class ToolBuiltinProviderDeleteApi(Resource): | ||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, provider): | def post(self, provider): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return BuiltinToolManageService.delete_builtin_tool_provider( | return BuiltinToolManageService.delete_builtin_tool_provider( | ||||
| user_id, | user_id, | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, provider): | def post(self, provider): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("url", type=str, required=True, nullable=False, location="args") | parser.add_argument("url", type=str, required=True, nullable=False, location="args") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| return ApiToolManageService.get_api_tool_provider_remote_schema( | return ApiToolManageService.get_api_tool_provider_remote_schema( | ||||
| current_user.id, | |||||
| current_user.current_tenant_id, | |||||
| user_id, | |||||
| tenant_id, | |||||
| args["url"], | args["url"], | ||||
| ) | ) | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, provider): | def get(self, provider): | ||||
| return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) | |||||
| user = current_user | |||||
| tenant_id = user.current_tenant_id | |||||
| return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) | |||||
| class ToolApiProviderSchemaApi(Resource): | class ToolApiProviderSchemaApi(Resource): | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| reqparser = reqparse.RequestParser() | reqparser = reqparse.RequestParser() | ||||
| reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") | reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| reqparser = reqparse.RequestParser() | reqparser = reqparse.RequestParser() | ||||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self): | def post(self): | ||||
| if not current_user.is_admin_or_owner: | |||||
| user = current_user | |||||
| if not user.is_admin_or_owner: | |||||
| raise Forbidden() | raise Forbidden() | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| reqparser = reqparse.RequestParser() | reqparser = reqparse.RequestParser() | ||||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") | parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") | parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder( | return jsonable_encoder( | ||||
| [ | [ | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder( | return jsonable_encoder( | ||||
| [ | [ | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| user_id = current_user.id | |||||
| tenant_id = current_user.current_tenant_id | |||||
| user = current_user | |||||
| user_id = user.id | |||||
| tenant_id = user.current_tenant_id | |||||
| return jsonable_encoder( | return jsonable_encoder( | ||||
| [ | [ | ||||
| api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") | api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") | ||||
| # builtin tool provider | # builtin tool provider | ||||
| api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools") | |||||
| api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete") | |||||
| api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update") | |||||
| api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools") | |||||
| api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info") | |||||
| api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete") | |||||
| api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update") | |||||
| api.add_resource( | api.add_resource( | ||||
| ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials" | |||||
| ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials" | |||||
| ) | ) | ||||
| api.add_resource( | api.add_resource( | ||||
| ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema" | |||||
| ToolBuiltinProviderCredentialsSchemaApi, | |||||
| "/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema", | |||||
| ) | ) | ||||
| api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon") | |||||
| api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon") | |||||
| # api tool provider | # api tool provider | ||||
| api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") | api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") | 
| from configs import dify_config | from configs import dify_config | ||||
| from controllers.console.workspace.error import AccountNotInitializedError | from controllers.console.workspace.error import AccountNotInitializedError | ||||
| from extensions.ext_database import db | |||||
| from models.model import DifySetup | from models.model import DifySetup | ||||
| from services.feature_service import FeatureService, LicenseStatus | from services.feature_service import FeatureService, LicenseStatus | ||||
| from services.operation_service import OperationService | from services.operation_service import OperationService | ||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | def decorated(*args, **kwargs): | ||||
| # check setup | # check setup | ||||
| if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first(): | |||||
| if ( | |||||
| dify_config.EDITION == "SELF_HOSTED" | |||||
| and os.environ.get("INIT_PASSWORD") | |||||
| and not db.session.query(DifySetup).first() | |||||
| ): | |||||
| raise NotInitValidateError() | raise NotInitValidateError() | ||||
| elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first(): | |||||
| elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): | |||||
| raise NotSetupError() | raise NotSetupError() | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) | 
| api = ExternalApi(bp) | api = ExternalApi(bp) | ||||
| from . import image_preview, tool_files | |||||
| from . import image_preview, tool_files, upload | 
| from flask import request | |||||
| from flask_restful import Resource, marshal_with # type: ignore | |||||
| from werkzeug.exceptions import Forbidden | |||||
| import services | |||||
| from controllers.console.wraps import setup_required | |||||
| from controllers.files import api | |||||
| from controllers.files.error import UnsupportedFileTypeError | |||||
| from controllers.inner_api.plugin.wraps import get_user | |||||
| from controllers.service_api.app.error import FileTooLargeError | |||||
| from core.file.helpers import verify_plugin_file_signature | |||||
| from fields.file_fields import file_fields | |||||
| from services.file_service import FileService | |||||
| class PluginUploadFileApi(Resource): | |||||
| @setup_required | |||||
| @marshal_with(file_fields) | |||||
| def post(self): | |||||
| # get file from request | |||||
| file = request.files["file"] | |||||
| timestamp = request.args.get("timestamp") | |||||
| nonce = request.args.get("nonce") | |||||
| sign = request.args.get("sign") | |||||
| tenant_id = request.args.get("tenant_id") | |||||
| if not tenant_id: | |||||
| raise Forbidden("Invalid request.") | |||||
| user_id = request.args.get("user_id") | |||||
| user = get_user(tenant_id, user_id) | |||||
| filename = file.filename | |||||
| mimetype = file.mimetype | |||||
| if not filename or not mimetype: | |||||
| raise Forbidden("Invalid request.") | |||||
| if not timestamp or not nonce or not sign: | |||||
| raise Forbidden("Invalid request.") | |||||
| if not verify_plugin_file_signature( | |||||
| filename=filename, | |||||
| mimetype=mimetype, | |||||
| tenant_id=tenant_id, | |||||
| user_id=user_id, | |||||
| timestamp=timestamp, | |||||
| nonce=nonce, | |||||
| sign=sign, | |||||
| ): | |||||
| raise Forbidden("Invalid request.") | |||||
| try: | |||||
| upload_file = FileService.upload_file( | |||||
| filename=filename, | |||||
| content=file.read(), | |||||
| mimetype=mimetype, | |||||
| user=user, | |||||
| source=None, | |||||
| ) | |||||
| except services.errors.file.FileTooLargeError as file_too_large_error: | |||||
| raise FileTooLargeError(file_too_large_error.description) | |||||
| except services.errors.file.UnsupportedFileTypeError: | |||||
| raise UnsupportedFileTypeError() | |||||
| return upload_file, 201 | |||||
| api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") | 
| bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") | bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") | ||||
| api = ExternalApi(bp) | api = ExternalApi(bp) | ||||
| from .plugin import plugin | |||||
| from .workspace import workspace | from .workspace import workspace | 
| from flask_restful import Resource # type: ignore | |||||
| from controllers.console.wraps import setup_required | |||||
| from controllers.inner_api import api | |||||
| from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data | |||||
| from controllers.inner_api.wraps import plugin_inner_api_only | |||||
| from core.file.helpers import get_signed_file_url_for_plugin | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation | |||||
| from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse | |||||
| from core.plugin.backwards_invocation.encrypt import PluginEncrypter | |||||
| from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation | |||||
| from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation | |||||
| from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation | |||||
| from core.plugin.entities.request import ( | |||||
| RequestInvokeApp, | |||||
| RequestInvokeEncrypt, | |||||
| RequestInvokeLLM, | |||||
| RequestInvokeModeration, | |||||
| RequestInvokeParameterExtractorNode, | |||||
| RequestInvokeQuestionClassifierNode, | |||||
| RequestInvokeRerank, | |||||
| RequestInvokeSpeech2Text, | |||||
| RequestInvokeSummary, | |||||
| RequestInvokeTextEmbedding, | |||||
| RequestInvokeTool, | |||||
| RequestInvokeTTS, | |||||
| RequestRequestUploadFile, | |||||
| ) | |||||
| from core.tools.entities.tool_entities import ToolProviderType | |||||
| from libs.helper import compact_generate_response | |||||
| from models.account import Account, Tenant | |||||
| from models.model import EndUser | |||||
| class PluginInvokeLLMApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeLLM) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): | |||||
| def generator(): | |||||
| response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) | |||||
| return PluginModelBackwardsInvocation.convert_to_event_stream(response) | |||||
| return compact_generate_response(generator()) | |||||
| class PluginInvokeTextEmbeddingApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeTextEmbedding) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginModelBackwardsInvocation.invoke_text_embedding( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeRerankApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeRerank) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginModelBackwardsInvocation.invoke_rerank( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeTTSApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeTTS) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS): | |||||
| def generator(): | |||||
| response = PluginModelBackwardsInvocation.invoke_tts( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| return PluginModelBackwardsInvocation.convert_to_event_stream(response) | |||||
| return compact_generate_response(generator()) | |||||
| class PluginInvokeSpeech2TextApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeSpeech2Text) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginModelBackwardsInvocation.invoke_speech2text( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeModerationApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeModeration) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginModelBackwardsInvocation.invoke_moderation( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeToolApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeTool) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): | |||||
| def generator(): | |||||
| return PluginToolBackwardsInvocation.convert_to_event_stream( | |||||
| PluginToolBackwardsInvocation.invoke_tool( | |||||
| tenant_id=tenant_model.id, | |||||
| user_id=user_model.id, | |||||
| tool_type=ToolProviderType.value_of(payload.tool_type), | |||||
| provider=payload.provider, | |||||
| tool_name=payload.tool, | |||||
| tool_parameters=payload.tool_parameters, | |||||
| ), | |||||
| ) | |||||
| return compact_generate_response(generator()) | |||||
| class PluginInvokeParameterExtractorNodeApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeParameterExtractorNode) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginNodeBackwardsInvocation.invoke_parameter_extractor( | |||||
| tenant_id=tenant_model.id, | |||||
| user_id=user_model.id, | |||||
| parameters=payload.parameters, | |||||
| model_config=payload.model, | |||||
| instruction=payload.instruction, | |||||
| query=payload.query, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeQuestionClassifierNodeApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): | |||||
| try: | |||||
| return jsonable_encoder( | |||||
| BaseBackwardsInvocationResponse( | |||||
| data=PluginNodeBackwardsInvocation.invoke_question_classifier( | |||||
| tenant_id=tenant_model.id, | |||||
| user_id=user_model.id, | |||||
| query=payload.query, | |||||
| model_config=payload.model, | |||||
| classes=payload.classes, | |||||
| instruction=payload.instruction, | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| except Exception as e: | |||||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||||
| class PluginInvokeAppApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeApp) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): | |||||
| response = PluginAppBackwardsInvocation.invoke_app( | |||||
| app_id=payload.app_id, | |||||
| user_id=user_model.id, | |||||
| tenant_id=tenant_model.id, | |||||
| conversation_id=payload.conversation_id, | |||||
| query=payload.query, | |||||
| stream=payload.response_mode == "streaming", | |||||
| inputs=payload.inputs, | |||||
| files=payload.files, | |||||
| ) | |||||
| return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response)) | |||||
| class PluginInvokeEncryptApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeEncrypt) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt): | |||||
| """ | |||||
| encrypt or decrypt data | |||||
| """ | |||||
| try: | |||||
| return BaseBackwardsInvocationResponse( | |||||
| data=PluginEncrypter.invoke_encrypt(tenant_model, payload) | |||||
| ).model_dump() | |||||
| except Exception as e: | |||||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||||
| class PluginInvokeSummaryApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestInvokeSummary) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary): | |||||
| try: | |||||
| return BaseBackwardsInvocationResponse( | |||||
| data={ | |||||
| "summary": PluginModelBackwardsInvocation.invoke_summary( | |||||
| user_id=user_model.id, | |||||
| tenant=tenant_model, | |||||
| payload=payload, | |||||
| ) | |||||
| } | |||||
| ).model_dump() | |||||
| except Exception as e: | |||||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||||
| class PluginUploadFileRequestApi(Resource): | |||||
| @setup_required | |||||
| @plugin_inner_api_only | |||||
| @get_user_tenant | |||||
| @plugin_data(payload_type=RequestRequestUploadFile) | |||||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): | |||||
| # generate signed url | |||||
| url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) | |||||
| return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() | |||||
| api.add_resource(PluginInvokeLLMApi, "/invoke/llm") | |||||
| api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") | |||||
| api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") | |||||
| api.add_resource(PluginInvokeTTSApi, "/invoke/tts") | |||||
| api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text") | |||||
| api.add_resource(PluginInvokeModerationApi, "/invoke/moderation") | |||||
| api.add_resource(PluginInvokeToolApi, "/invoke/tool") | |||||
| api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor") | |||||
| api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier") | |||||
| api.add_resource(PluginInvokeAppApi, "/invoke/app") | |||||
| api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") | |||||
| api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") | |||||
| api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") | 
| from collections.abc import Callable | |||||
| from functools import wraps | |||||
| from typing import Optional | |||||
| from flask import request | |||||
| from flask_restful import reqparse # type: ignore | |||||
| from pydantic import BaseModel | |||||
| from sqlalchemy.orm import Session | |||||
| from extensions.ext_database import db | |||||
| from models.account import Account, Tenant | |||||
| from models.model import EndUser | |||||
| from services.account_service import AccountService | |||||
| def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: | |||||
| try: | |||||
| with Session(db.engine) as session: | |||||
| if not user_id: | |||||
| user_id = "DEFAULT-USER" | |||||
| if user_id == "DEFAULT-USER": | |||||
| user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() | |||||
| if not user_model: | |||||
| user_model = EndUser( | |||||
| tenant_id=tenant_id, | |||||
| type="service_api", | |||||
| is_anonymous=True if user_id == "DEFAULT-USER" else False, | |||||
| session_id=user_id, | |||||
| ) | |||||
| session.add(user_model) | |||||
| session.commit() | |||||
| else: | |||||
| user_model = AccountService.load_user(user_id) | |||||
| if not user_model: | |||||
| user_model = session.query(EndUser).filter(EndUser.id == user_id).first() | |||||
| if not user_model: | |||||
| raise ValueError("user not found") | |||||
| except Exception: | |||||
| raise ValueError("user not found") | |||||
| return user_model | |||||
| def get_user_tenant(view: Optional[Callable] = None): | |||||
| def decorator(view_func): | |||||
| @wraps(view_func) | |||||
| def decorated_view(*args, **kwargs): | |||||
| # fetch json body | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("tenant_id", type=str, required=True, location="json") | |||||
| parser.add_argument("user_id", type=str, required=True, location="json") | |||||
| kwargs = parser.parse_args() | |||||
| user_id = kwargs.get("user_id") | |||||
| tenant_id = kwargs.get("tenant_id") | |||||
| if not tenant_id: | |||||
| raise ValueError("tenant_id is required") | |||||
| if not user_id: | |||||
| user_id = "DEFAULT-USER" | |||||
| del kwargs["tenant_id"] | |||||
| del kwargs["user_id"] | |||||
| try: | |||||
| tenant_model = ( | |||||
| db.session.query(Tenant) | |||||
| .filter( | |||||
| Tenant.id == tenant_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| except Exception: | |||||
| raise ValueError("tenant not found") | |||||
| if not tenant_model: | |||||
| raise ValueError("tenant not found") | |||||
| kwargs["tenant_model"] = tenant_model | |||||
| kwargs["user_model"] = get_user(tenant_id, user_id) | |||||
| return view_func(*args, **kwargs) | |||||
| return decorated_view | |||||
| if view is None: | |||||
| return decorator | |||||
| else: | |||||
| return decorator(view) | |||||
| def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): | |||||
| def decorator(view_func): | |||||
| def decorated_view(*args, **kwargs): | |||||
| try: | |||||
| data = request.get_json() | |||||
| except Exception: | |||||
| raise ValueError("invalid json") | |||||
| try: | |||||
| payload = payload_type(**data) | |||||
| except Exception as e: | |||||
| raise ValueError(f"invalid payload: {str(e)}") | |||||
| kwargs["payload"] = payload | |||||
| return view_func(*args, **kwargs) | |||||
| return decorated_view | |||||
| if view is None: | |||||
| return decorator | |||||
| else: | |||||
| return decorator(view) | 
| from controllers.console.wraps import setup_required | from controllers.console.wraps import setup_required | ||||
| from controllers.inner_api import api | from controllers.inner_api import api | ||||
| from controllers.inner_api.wraps import inner_api_only | |||||
| from controllers.inner_api.wraps import enterprise_inner_api_only | |||||
| from events.tenant_event import tenant_was_created | from events.tenant_event import tenant_was_created | ||||
| from models.account import Account | from models.account import Account | ||||
| from services.account_service import TenantService | from services.account_service import TenantService | ||||
| class EnterpriseWorkspace(Resource): | class EnterpriseWorkspace(Resource): | ||||
| @setup_required | @setup_required | ||||
| @inner_api_only | |||||
| @enterprise_inner_api_only | |||||
| def post(self): | def post(self): | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("name", type=str, required=True, location="json") | parser.add_argument("name", type=str, required=True, location="json") | ||||
| class EnterpriseWorkspaceNoOwnerEmail(Resource): | class EnterpriseWorkspaceNoOwnerEmail(Resource): | ||||
| @setup_required | @setup_required | ||||
| @inner_api_only | |||||
| @enterprise_inner_api_only | |||||
| def post(self): | def post(self): | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("name", type=str, required=True, location="json") | parser.add_argument("name", type=str, required=True, location="json") | 
| from models.model import EndUser | from models.model import EndUser | ||||
| def inner_api_only(view): | |||||
| def enterprise_inner_api_only(view): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | def decorated(*args, **kwargs): | ||||
| if not dify_config.INNER_API: | if not dify_config.INNER_API: | ||||
| # get header 'X-Inner-Api-Key' | # get header 'X-Inner-Api-Key' | ||||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | inner_api_key = request.headers.get("X-Inner-Api-Key") | ||||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: | |||||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: | |||||
| abort(401) | abort(401) | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) | ||||
| return decorated | return decorated | ||||
| def inner_api_user_auth(view): | |||||
| def enterprise_inner_api_user_auth(view): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | def decorated(*args, **kwargs): | ||||
| if not dify_config.INNER_API: | if not dify_config.INNER_API: | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) | ||||
| return decorated | return decorated | ||||
| def plugin_inner_api_only(view): | |||||
| @wraps(view) | |||||
| def decorated(*args, **kwargs): | |||||
| if not dify_config.PLUGIN_DAEMON_KEY: | |||||
| abort(404) | |||||
| # get header 'X-Inner-Api-Key' | |||||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | |||||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: | |||||
| abort(404) | |||||
| return view(*args, **kwargs) | |||||
| return decorated | 
| import json | import json | ||||
| import logging | import logging | ||||
| import uuid | import uuid | ||||
| from datetime import UTC, datetime | |||||
| from typing import Optional, Union, cast | from typing import Optional, Union, cast | ||||
| from core.agent.entities import AgentEntity, AgentToolEntity | from core.agent.entities import AgentEntity, AgentToolEntity | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | from core.model_runtime.entities.message_entities import ImagePromptMessageContent | ||||
| from core.model_runtime.entities.model_entities import ModelFeature | from core.model_runtime.entities.model_entities import ModelFeature | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | from core.prompt.utils.extract_thread_messages import extract_thread_messages | ||||
| from core.tools.__base.tool import Tool | |||||
| from core.tools.entities.tool_entities import ( | from core.tools.entities.tool_entities import ( | ||||
| ToolParameter, | ToolParameter, | ||||
| ToolRuntimeVariablePool, | |||||
| ) | ) | ||||
| from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool | |||||
| from core.tools.tool.tool import Tool | |||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from models.model import Conversation, Message, MessageAgentThought, MessageFile | from models.model import Conversation, Message, MessageAgentThought, MessageFile | ||||
| from models.tools import ToolConversationVariables | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| message: Message, | message: Message, | ||||
| user_id: str, | user_id: str, | ||||
| model_instance: ModelInstance, | |||||
| memory: Optional[TokenBufferMemory] = None, | memory: Optional[TokenBufferMemory] = None, | ||||
| prompt_messages: Optional[list[PromptMessage]] = None, | prompt_messages: Optional[list[PromptMessage]] = None, | ||||
| variables_pool: Optional[ToolRuntimeVariablePool] = None, | |||||
| db_variables: Optional[ToolConversationVariables] = None, | |||||
| model_instance: ModelInstance, | |||||
| ) -> None: | ) -> None: | ||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.application_generate_entity = application_generate_entity | self.application_generate_entity = application_generate_entity | ||||
| self.user_id = user_id | self.user_id = user_id | ||||
| self.memory = memory | self.memory = memory | ||||
| self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) | self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) | ||||
| self.variables_pool = variables_pool | |||||
| self.db_variables_pool = db_variables | |||||
| self.model_instance = model_instance | self.model_instance = model_instance | ||||
| # init callback | # init callback | ||||
| agent_tool=tool, | agent_tool=tool, | ||||
| invoke_from=self.application_generate_entity.invoke_from, | invoke_from=self.application_generate_entity.invoke_from, | ||||
| ) | ) | ||||
| tool_entity.load_variables(self.variables_pool) | |||||
| assert tool_entity.entity.description | |||||
| message_tool = PromptMessageTool( | message_tool = PromptMessageTool( | ||||
| name=tool.tool_name, | name=tool.tool_name, | ||||
| description=tool_entity.description.llm if tool_entity.description else "", | |||||
| description=tool_entity.entity.description.llm, | |||||
| parameters={ | parameters={ | ||||
| "type": "object", | "type": "object", | ||||
| "properties": {}, | "properties": {}, | ||||
| }, | }, | ||||
| ) | ) | ||||
| parameters = tool_entity.get_all_runtime_parameters() | |||||
| parameters = tool_entity.get_merged_runtime_parameters() | |||||
| for parameter in parameters: | for parameter in parameters: | ||||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | if parameter.form != ToolParameter.ToolParameterForm.LLM: | ||||
| continue | continue | ||||
| """ | """ | ||||
| convert dataset retriever tool to prompt message tool | convert dataset retriever tool to prompt message tool | ||||
| """ | """ | ||||
| assert tool.entity.description | |||||
| prompt_tool = PromptMessageTool( | prompt_tool = PromptMessageTool( | ||||
| name=tool.identity.name if tool.identity else "unknown", | |||||
| description=tool.description.llm if tool.description else "", | |||||
| name=tool.entity.identity.name, | |||||
| description=tool.entity.description.llm, | |||||
| parameters={ | parameters={ | ||||
| "type": "object", | "type": "object", | ||||
| "properties": {}, | "properties": {}, | ||||
| # save prompt tool | # save prompt tool | ||||
| prompt_messages_tools.append(prompt_tool) | prompt_messages_tools.append(prompt_tool) | ||||
| # save tool entity | # save tool entity | ||||
| if dataset_tool.identity is not None: | |||||
| tool_instances[dataset_tool.identity.name] = dataset_tool | |||||
| tool_instances[dataset_tool.entity.identity.name] = dataset_tool | |||||
| return tool_instances, prompt_messages_tools | return tool_instances, prompt_messages_tools | ||||
| def save_agent_thought( | def save_agent_thought( | ||||
| self, | self, | ||||
| agent_thought: MessageAgentThought, | agent_thought: MessageAgentThought, | ||||
| tool_name: str, | |||||
| tool_input: Union[str, dict], | |||||
| thought: str, | |||||
| tool_name: str | None, | |||||
| tool_input: Union[str, dict, None], | |||||
| thought: str | None, | |||||
| observation: Union[str, dict, None], | observation: Union[str, dict, None], | ||||
| tool_invoke_meta: Union[str, dict, None], | tool_invoke_meta: Union[str, dict, None], | ||||
| answer: str, | |||||
| answer: str | None, | |||||
| messages_ids: list[str], | messages_ids: list[str], | ||||
| llm_usage: LLMUsage | None = None, | llm_usage: LLMUsage | None = None, | ||||
| ): | ): | ||||
| """ | """ | ||||
| Save agent thought | Save agent thought | ||||
| """ | """ | ||||
| queried_thought = ( | |||||
| updated_agent_thought = ( | |||||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() | db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() | ||||
| ) | ) | ||||
| if not queried_thought: | |||||
| raise ValueError(f"Agent thought {agent_thought.id} not found") | |||||
| agent_thought = queried_thought | |||||
| if not updated_agent_thought: | |||||
| raise ValueError("agent thought not found") | |||||
| agent_thought = updated_agent_thought | |||||
| if thought: | if thought: | ||||
| agent_thought.thought = thought | agent_thought.thought = thought | ||||
| if isinstance(tool_input, dict): | if isinstance(tool_input, dict): | ||||
| try: | try: | ||||
| tool_input = json.dumps(tool_input, ensure_ascii=False) | tool_input = json.dumps(tool_input, ensure_ascii=False) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| tool_input = json.dumps(tool_input) | tool_input = json.dumps(tool_input) | ||||
| agent_thought.tool_input = tool_input | |||||
| updated_agent_thought.tool_input = tool_input | |||||
| if observation: | if observation: | ||||
| if isinstance(observation, dict): | if isinstance(observation, dict): | ||||
| try: | try: | ||||
| observation = json.dumps(observation, ensure_ascii=False) | observation = json.dumps(observation, ensure_ascii=False) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| observation = json.dumps(observation) | observation = json.dumps(observation) | ||||
| agent_thought.observation = observation | |||||
| updated_agent_thought.observation = observation | |||||
| if answer: | if answer: | ||||
| agent_thought.answer = answer | agent_thought.answer = answer | ||||
| if messages_ids is not None and len(messages_ids) > 0: | if messages_ids is not None and len(messages_ids) > 0: | ||||
| agent_thought.message_files = json.dumps(messages_ids) | |||||
| updated_agent_thought.message_files = json.dumps(messages_ids) | |||||
| if llm_usage: | if llm_usage: | ||||
| agent_thought.message_token = llm_usage.prompt_tokens | |||||
| agent_thought.message_price_unit = llm_usage.prompt_price_unit | |||||
| agent_thought.message_unit_price = llm_usage.prompt_unit_price | |||||
| agent_thought.answer_token = llm_usage.completion_tokens | |||||
| agent_thought.answer_price_unit = llm_usage.completion_price_unit | |||||
| agent_thought.answer_unit_price = llm_usage.completion_unit_price | |||||
| agent_thought.tokens = llm_usage.total_tokens | |||||
| agent_thought.total_price = llm_usage.total_price | |||||
| updated_agent_thought.message_token = llm_usage.prompt_tokens | |||||
| updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit | |||||
| updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price | |||||
| updated_agent_thought.answer_token = llm_usage.completion_tokens | |||||
| updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit | |||||
| updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price | |||||
| updated_agent_thought.tokens = llm_usage.total_tokens | |||||
| updated_agent_thought.total_price = llm_usage.total_price | |||||
| # check if tool labels is not empty | # check if tool labels is not empty | ||||
| labels = agent_thought.tool_labels or {} | |||||
| tools = agent_thought.tool.split(";") if agent_thought.tool else [] | |||||
| labels = updated_agent_thought.tool_labels or {} | |||||
| tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] | |||||
| for tool in tools: | for tool in tools: | ||||
| if not tool: | if not tool: | ||||
| continue | continue | ||||
| else: | else: | ||||
| labels[tool] = {"en_US": tool, "zh_Hans": tool} | labels[tool] = {"en_US": tool, "zh_Hans": tool} | ||||
| agent_thought.tool_labels_str = json.dumps(labels) | |||||
| updated_agent_thought.tool_labels_str = json.dumps(labels) | |||||
| if tool_invoke_meta is not None: | if tool_invoke_meta is not None: | ||||
| if isinstance(tool_invoke_meta, dict): | if isinstance(tool_invoke_meta, dict): | ||||
| try: | try: | ||||
| tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) | tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| tool_invoke_meta = json.dumps(tool_invoke_meta) | tool_invoke_meta = json.dumps(tool_invoke_meta) | ||||
| agent_thought.tool_meta_str = tool_invoke_meta | |||||
| db.session.commit() | |||||
| db.session.close() | |||||
| def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): | |||||
| """ | |||||
| convert tool variables to db variables | |||||
| """ | |||||
| queried_variables = ( | |||||
| db.session.query(ToolConversationVariables) | |||||
| .filter( | |||||
| ToolConversationVariables.conversation_id == self.message.conversation_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if not queried_variables: | |||||
| return | |||||
| updated_agent_thought.tool_meta_str = tool_invoke_meta | |||||
| db_variables = queried_variables | |||||
| db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) | |||||
| db.session.commit() | db.session.commit() | ||||
| db.session.close() | db.session.close() | ||||
| tool_call_response: list[ToolPromptMessage] = [] | tool_call_response: list[ToolPromptMessage] = [] | ||||
| try: | try: | ||||
| tool_inputs = json.loads(agent_thought.tool_input) | tool_inputs = json.loads(agent_thought.tool_input) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| tool_inputs = {tool: {} for tool in tools} | tool_inputs = {tool: {} for tool in tools} | ||||
| try: | try: | ||||
| tool_responses = json.loads(agent_thought.observation) | tool_responses = json.loads(agent_thought.observation) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| tool_responses = dict.fromkeys(tools, agent_thought.observation) | tool_responses = dict.fromkeys(tools, agent_thought.observation) | ||||
| for tool in tools: | for tool in tools: | ||||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | ||||
| if not files: | if not files: | ||||
| return UserPromptMessage(content=message.query) | return UserPromptMessage(content=message.query) | ||||
| file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) | |||||
| if message.app_model_config: | |||||
| file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) | |||||
| else: | |||||
| file_extra_config = None | |||||
| if not file_extra_config: | if not file_extra_config: | ||||
| return UserPromptMessage(content=message.query) | return UserPromptMessage(content=message.query) | ||||
| import json | import json | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from collections.abc import Generator, Mapping | |||||
| from collections.abc import Generator, Mapping, Sequence | |||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from core.agent.base_agent_runner import BaseAgentRunner | from core.agent.base_agent_runner import BaseAgentRunner | ||||
| ) | ) | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform | from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform | ||||
| from core.tools.__base.tool import Tool | |||||
| from core.tools.entities.tool_entities import ToolInvokeMeta | from core.tools.entities.tool_entities import ToolInvokeMeta | ||||
| from core.tools.tool.tool import Tool | |||||
| from core.tools.tool_engine import ToolEngine | from core.tools.tool_engine import ToolEngine | ||||
| from models.model import Message | from models.model import Message | ||||
| class CotAgentRunner(BaseAgentRunner, ABC): | class CotAgentRunner(BaseAgentRunner, ABC): | ||||
| _is_first_iteration = True | _is_first_iteration = True | ||||
| _ignore_observation_providers = ["wenxin"] | _ignore_observation_providers = ["wenxin"] | ||||
| _historic_prompt_messages: list[PromptMessage] | None = None | |||||
| _agent_scratchpad: list[AgentScratchpadUnit] | None = None | |||||
| _instruction: str = "" # FIXME this must be str for now | |||||
| _query: str | None = None | |||||
| _prompt_messages_tools: list[PromptMessageTool] = [] | |||||
| _historic_prompt_messages: list[PromptMessage] | |||||
| _agent_scratchpad: list[AgentScratchpadUnit] | |||||
| _instruction: str | |||||
| _query: str | |||||
| _prompt_messages_tools: Sequence[PromptMessageTool] | |||||
| def run( | def run( | ||||
| self, | self, | ||||
| """ | """ | ||||
| Run Cot agent application | Run Cot agent application | ||||
| """ | """ | ||||
| app_generate_entity = self.application_generate_entity | app_generate_entity = self.application_generate_entity | ||||
| self._repack_app_generate_entity(app_generate_entity) | self._repack_app_generate_entity(app_generate_entity) | ||||
| self._init_react_state(query) | self._init_react_state(query) | ||||
| app_generate_entity.model_conf.stop.append("Observation") | app_generate_entity.model_conf.stop.append("Observation") | ||||
| app_config = self.app_config | app_config = self.app_config | ||||
| assert app_config.agent | |||||
| # init instruction | # init instruction | ||||
| inputs = inputs or {} | inputs = inputs or {} | ||||
| instruction = app_config.prompt_template.simple_prompt_template | |||||
| self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) | |||||
| instruction = app_config.prompt_template.simple_prompt_template or "" | |||||
| self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) | |||||
| iteration_step = 1 | iteration_step = 1 | ||||
| max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 | max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 | ||||
| # convert tools into ModelRuntime Tool format | # convert tools into ModelRuntime Tool format | ||||
| tool_instances, self._prompt_messages_tools = self._init_prompt_tools() | |||||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||||
| self._prompt_messages_tools = prompt_messages_tools | |||||
| function_call_state = True | function_call_state = True | ||||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | ||||
| callbacks=[], | callbacks=[], | ||||
| ) | ) | ||||
| if not isinstance(chunks, Generator): | |||||
| raise ValueError("Expected streaming response from LLM") | |||||
| # check llm result | |||||
| if not chunks: | |||||
| raise ValueError("failed to invoke llm") | |||||
| usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} | |||||
| usage_dict: dict[str, Optional[LLMUsage]] = {} | |||||
| react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) | react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) | ||||
| scratchpad = AgentScratchpadUnit( | scratchpad = AgentScratchpadUnit( | ||||
| agent_response="", | agent_response="", | ||||
| if isinstance(chunk, AgentScratchpadUnit.Action): | if isinstance(chunk, AgentScratchpadUnit.Action): | ||||
| action = chunk | action = chunk | ||||
| # detect action | # detect action | ||||
| if scratchpad.agent_response is not None: | |||||
| scratchpad.agent_response += json.dumps(chunk.model_dump()) | |||||
| assert scratchpad.agent_response is not None | |||||
| scratchpad.agent_response += json.dumps(chunk.model_dump()) | |||||
| scratchpad.action_str = json.dumps(chunk.model_dump()) | scratchpad.action_str = json.dumps(chunk.model_dump()) | ||||
| scratchpad.action = action | scratchpad.action = action | ||||
| else: | else: | ||||
| if scratchpad.agent_response is not None: | |||||
| scratchpad.agent_response += chunk | |||||
| if scratchpad.thought is not None: | |||||
| scratchpad.thought += chunk | |||||
| assert scratchpad.agent_response is not None | |||||
| scratchpad.agent_response += chunk | |||||
| assert scratchpad.thought is not None | |||||
| scratchpad.thought += chunk | |||||
| yield LLMResultChunk( | yield LLMResultChunk( | ||||
| model=self.model_config.model, | model=self.model_config.model, | ||||
| prompt_messages=prompt_messages, | prompt_messages=prompt_messages, | ||||
| system_fingerprint="", | system_fingerprint="", | ||||
| delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), | delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), | ||||
| ) | ) | ||||
| if scratchpad.thought is not None: | |||||
| scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" | |||||
| if self._agent_scratchpad is not None: | |||||
| self._agent_scratchpad.append(scratchpad) | |||||
| assert scratchpad.thought is not None | |||||
| scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" | |||||
| self._agent_scratchpad.append(scratchpad) | |||||
| # get llm usage | # get llm usage | ||||
| if "usage" in usage_dict: | if "usage" in usage_dict: | ||||
| answer=final_answer, | answer=final_answer, | ||||
| messages_ids=[], | messages_ids=[], | ||||
| ) | ) | ||||
| if self.variables_pool is not None and self.db_variables_pool is not None: | |||||
| self.update_db_variables(self.variables_pool, self.db_variables_pool) | |||||
| # publish end event | # publish end event | ||||
| self.queue_manager.publish( | self.queue_manager.publish( | ||||
| QueueMessageEndEvent( | QueueMessageEndEvent( | ||||
| def _handle_invoke_action( | def _handle_invoke_action( | ||||
| self, | self, | ||||
| action: AgentScratchpadUnit.Action, | action: AgentScratchpadUnit.Action, | ||||
| tool_instances: dict[str, Tool], | |||||
| tool_instances: Mapping[str, Tool], | |||||
| message_file_ids: list[str], | message_file_ids: list[str], | ||||
| trace_manager: Optional[TraceQueueManager] = None, | trace_manager: Optional[TraceQueueManager] = None, | ||||
| ) -> tuple[str, ToolInvokeMeta]: | ) -> tuple[str, ToolInvokeMeta]: | ||||
| ) | ) | ||||
| # publish files | # publish files | ||||
| for message_file_id, save_as in message_files: | |||||
| if save_as is not None and self.variables_pool: | |||||
| # FIXME the save_as type is confusing, it should be a string or not | |||||
| self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) | |||||
| for message_file_id in message_files: | |||||
| # publish message file | # publish message file | ||||
| self.queue_manager.publish( | self.queue_manager.publish( | ||||
| QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | ||||
| for key, value in inputs.items(): | for key, value in inputs.items(): | ||||
| try: | try: | ||||
| instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) | instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| continue | continue | ||||
| return instruction | return instruction | ||||
| return message | return message | ||||
| def _organize_historic_prompt_messages( | def _organize_historic_prompt_messages( | ||||
| self, current_session_messages: Optional[list[PromptMessage]] = None | |||||
| self, current_session_messages: list[PromptMessage] | None = None | |||||
| ) -> list[PromptMessage]: | ) -> list[PromptMessage]: | ||||
| """ | """ | ||||
| organize historic prompt messages | organize historic prompt messages | ||||
| for message in self.history_prompt_messages: | for message in self.history_prompt_messages: | ||||
| if isinstance(message, AssistantPromptMessage): | if isinstance(message, AssistantPromptMessage): | ||||
| if not current_scratchpad: | if not current_scratchpad: | ||||
| if not isinstance(message.content, str | None): | |||||
| raise NotImplementedError("expected str type") | |||||
| assert isinstance(message.content, str) | |||||
| current_scratchpad = AgentScratchpadUnit( | current_scratchpad = AgentScratchpadUnit( | ||||
| agent_response=message.content, | agent_response=message.content, | ||||
| thought=message.content or "I am thinking about how to help you", | thought=message.content or "I am thinking about how to help you", | ||||
| except: | except: | ||||
| pass | pass | ||||
| elif isinstance(message, ToolPromptMessage): | elif isinstance(message, ToolPromptMessage): | ||||
| if not current_scratchpad: | |||||
| continue | |||||
| if isinstance(message.content, str): | |||||
| if current_scratchpad: | |||||
| assert isinstance(message.content, str) | |||||
| current_scratchpad.observation = message.content | current_scratchpad.observation = message.content | ||||
| else: | else: | ||||
| raise NotImplementedError("expected str type") | raise NotImplementedError("expected str type") | 
| """ | """ | ||||
| Organize system prompt | Organize system prompt | ||||
| """ | """ | ||||
| if not self.app_config.agent: | |||||
| raise ValueError("Agent configuration is not set") | |||||
| assert self.app_config.agent | |||||
| assert self.app_config.agent.prompt | |||||
| prompt_entity = self.app_config.agent.prompt | prompt_entity = self.app_config.agent.prompt | ||||
| if not prompt_entity: | if not prompt_entity: | ||||
| assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str | assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str | ||||
| for unit in agent_scratchpad: | for unit in agent_scratchpad: | ||||
| if unit.is_final(): | if unit.is_final(): | ||||
| assert isinstance(assistant_message.content, str) | |||||
| assistant_message.content += f"Final Answer: {unit.agent_response}" | assistant_message.content += f"Final Answer: {unit.agent_response}" | ||||
| else: | else: | ||||
| assert isinstance(assistant_message.content, str) | |||||
| assistant_message.content += f"Thought: {unit.thought}\n\n" | assistant_message.content += f"Thought: {unit.thought}\n\n" | ||||
| if unit.action_str: | if unit.action_str: | ||||
| assistant_message.content += f"Action: {unit.action_str}\n\n" | assistant_message.content += f"Action: {unit.action_str}\n\n" | 
| from enum import Enum | |||||
| from typing import Any, Literal, Optional, Union | |||||
| from enum import StrEnum | |||||
| from typing import Any, Optional, Union | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||||
| class AgentToolEntity(BaseModel): | class AgentToolEntity(BaseModel): | ||||
| """ | """ | ||||
| Agent Tool Entity. | Agent Tool Entity. | ||||
| """ | """ | ||||
| provider_type: Literal["builtin", "api", "workflow"] | |||||
| provider_type: ToolProviderType | |||||
| provider_id: str | provider_id: str | ||||
| tool_name: str | tool_name: str | ||||
| tool_parameters: dict[str, Any] = {} | tool_parameters: dict[str, Any] = {} | ||||
| plugin_unique_identifier: str | None = None | |||||
| class AgentPromptEntity(BaseModel): | class AgentPromptEntity(BaseModel): | ||||
| Agent Entity. | Agent Entity. | ||||
| """ | """ | ||||
| class Strategy(Enum): | |||||
| class Strategy(StrEnum): | |||||
| """ | """ | ||||
| Agent Strategy. | Agent Strategy. | ||||
| """ | """ | ||||
| model: str | model: str | ||||
| strategy: Strategy | strategy: Strategy | ||||
| prompt: Optional[AgentPromptEntity] = None | prompt: Optional[AgentPromptEntity] = None | ||||
| tools: list[AgentToolEntity] | None = None | |||||
| tools: Optional[list[AgentToolEntity]] = None | |||||
| max_iteration: int = 5 | max_iteration: int = 5 | ||||
| class AgentInvokeMessage(ToolInvokeMessage): | |||||
| """ | |||||
| Agent Invoke Message. | |||||
| """ | |||||
| pass | 
| # convert tools into ModelRuntime Tool format | # convert tools into ModelRuntime Tool format | ||||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | tool_instances, prompt_messages_tools = self._init_prompt_tools() | ||||
| assert app_config.agent | |||||
| iteration_step = 1 | iteration_step = 1 | ||||
| max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 | max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 | ||||
| # continue to run until there is not any tool call | # continue to run until there is not any tool call | ||||
| function_call_state = True | function_call_state = True | ||||
| llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} | |||||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||||
| final_answer = "" | final_answer = "" | ||||
| # get tracing instance | # get tracing instance | ||||
| trace_manager = app_generate_entity.trace_manager | trace_manager = app_generate_entity.trace_manager | ||||
| def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): | |||||
| def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): | |||||
| if not final_llm_usage_dict["usage"]: | if not final_llm_usage_dict["usage"]: | ||||
| final_llm_usage_dict["usage"] = usage | final_llm_usage_dict["usage"] = usage | ||||
| else: | else: | ||||
| current_llm_usage = None | current_llm_usage = None | ||||
| if self.stream_tool_call and isinstance(chunks, Generator): | |||||
| if isinstance(chunks, Generator): | |||||
| is_first_chunk = True | is_first_chunk = True | ||||
| for chunk in chunks: | for chunk in chunks: | ||||
| if is_first_chunk: | if is_first_chunk: | ||||
| tool_call_inputs = json.dumps( | tool_call_inputs = json.dumps( | ||||
| {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | ||||
| ) | ) | ||||
| except json.JSONDecodeError as e: | |||||
| except json.JSONDecodeError: | |||||
| # ensure ascii to avoid encoding error | # ensure ascii to avoid encoding error | ||||
| tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | ||||
| current_llm_usage = chunk.delta.usage | current_llm_usage = chunk.delta.usage | ||||
| yield chunk | yield chunk | ||||
| elif not self.stream_tool_call and isinstance(chunks, LLMResult): | |||||
| else: | |||||
| result = chunks | result = chunks | ||||
| # check if there is any tool call | # check if there is any tool call | ||||
| if self.check_blocking_tool_calls(result): | if self.check_blocking_tool_calls(result): | ||||
| tool_call_inputs = json.dumps( | tool_call_inputs = json.dumps( | ||||
| {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False | ||||
| ) | ) | ||||
| except json.JSONDecodeError as e: | |||||
| except json.JSONDecodeError: | |||||
| # ensure ascii to avoid encoding error | # ensure ascii to avoid encoding error | ||||
| tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) | ||||
| usage=result.usage, | usage=result.usage, | ||||
| ), | ), | ||||
| ) | ) | ||||
| else: | |||||
| raise RuntimeError(f"invalid chunks type: {type(chunks)}") | |||||
| assistant_message = AssistantPromptMessage(content="", tool_calls=[]) | assistant_message = AssistantPromptMessage(content="", tool_calls=[]) | ||||
| if tool_calls: | if tool_calls: | ||||
| invoke_from=self.application_generate_entity.invoke_from, | invoke_from=self.application_generate_entity.invoke_from, | ||||
| agent_tool_callback=self.agent_callback, | agent_tool_callback=self.agent_callback, | ||||
| trace_manager=trace_manager, | trace_manager=trace_manager, | ||||
| app_id=self.application_generate_entity.app_config.app_id, | |||||
| message_id=self.message.id, | |||||
| conversation_id=self.conversation.id, | |||||
| ) | ) | ||||
| # publish files | # publish files | ||||
| for message_file_id, save_as in message_files: | |||||
| if save_as: | |||||
| if self.variables_pool: | |||||
| self.variables_pool.set_file( | |||||
| tool_name=tool_call_name, value=message_file_id, name=save_as | |||||
| ) | |||||
| for message_file_id in message_files: | |||||
| # publish message file | # publish message file | ||||
| self.queue_manager.publish( | self.queue_manager.publish( | ||||
| QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER | ||||
| iteration_step += 1 | iteration_step += 1 | ||||
| if self.variables_pool and self.db_variables_pool: | |||||
| self.update_db_variables(self.variables_pool, self.db_variables_pool) | |||||
| # publish end event | # publish end event | ||||
| self.queue_manager.publish( | self.queue_manager.publish( | ||||
| QueueMessageEndEvent( | QueueMessageEndEvent( | ||||
| return True | return True | ||||
| return False | return False | ||||
| def extract_tool_calls( | |||||
| self, llm_result_chunk: LLMResultChunk | |||||
| ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||||
| def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: | |||||
| """ | """ | ||||
| Extract tool calls from llm result chunk | Extract tool calls from llm result chunk | ||||
| return tool_calls | return tool_calls | ||||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: | |||||
| """ | """ | ||||
| Extract blocking tool calls from llm result | Extract blocking tool calls from llm result | ||||
| return tool_calls | return tool_calls | ||||
| def _init_system_message( | |||||
| self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None | |||||
| ) -> list[PromptMessage]: | |||||
| def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: | |||||
| """ | """ | ||||
| Initialize system message | Initialize system message | ||||
| """ | """ | 
| import enum | |||||
| from typing import Any, Optional | |||||
| from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator | |||||
| from core.entities.parameter_entities import CommonParameterType | |||||
| from core.plugin.entities.parameters import ( | |||||
| PluginParameter, | |||||
| as_normal_type, | |||||
| cast_parameter_value, | |||||
| init_frontend_parameter, | |||||
| ) | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| from core.tools.entities.tool_entities import ( | |||||
| ToolIdentity, | |||||
| ToolProviderIdentity, | |||||
| ) | |||||
| class AgentStrategyProviderIdentity(ToolProviderIdentity): | |||||
| """ | |||||
| Inherits from ToolProviderIdentity, without any additional fields. | |||||
| """ | |||||
| pass | |||||
| class AgentStrategyParameter(PluginParameter): | |||||
| class AgentStrategyParameterType(enum.StrEnum): | |||||
| """ | |||||
| Keep all the types from PluginParameterType | |||||
| """ | |||||
| STRING = CommonParameterType.STRING.value | |||||
| NUMBER = CommonParameterType.NUMBER.value | |||||
| BOOLEAN = CommonParameterType.BOOLEAN.value | |||||
| SELECT = CommonParameterType.SELECT.value | |||||
| SECRET_INPUT = CommonParameterType.SECRET_INPUT.value | |||||
| FILE = CommonParameterType.FILE.value | |||||
| FILES = CommonParameterType.FILES.value | |||||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||||
| TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value | |||||
| # deprecated, should not use. | |||||
| SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value | |||||
| def as_normal_type(self): | |||||
| return as_normal_type(self) | |||||
| def cast_value(self, value: Any): | |||||
| return cast_parameter_value(self, value) | |||||
| type: AgentStrategyParameterType = Field(..., description="The type of the parameter") | |||||
| def init_frontend_parameter(self, value: Any): | |||||
| return init_frontend_parameter(self, self.type, value) | |||||
| class AgentStrategyProviderEntity(BaseModel): | |||||
| identity: AgentStrategyProviderIdentity | |||||
| plugin_id: Optional[str] = Field(None, description="The id of the plugin") | |||||
| class AgentStrategyIdentity(ToolIdentity): | |||||
| """ | |||||
| Inherits from ToolIdentity, without any additional fields. | |||||
| """ | |||||
| pass | |||||
| class AgentStrategyEntity(BaseModel): | |||||
| identity: AgentStrategyIdentity | |||||
| parameters: list[AgentStrategyParameter] = Field(default_factory=list) | |||||
| description: I18nObject = Field(..., description="The description of the agent strategy") | |||||
| output_schema: Optional[dict] = None | |||||
| # pydantic configs | |||||
| model_config = ConfigDict(protected_namespaces=()) | |||||
| @field_validator("parameters", mode="before") | |||||
| @classmethod | |||||
| def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]: | |||||
| return v or [] | |||||
| class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity): | |||||
| strategies: list[AgentStrategyEntity] = Field(default_factory=list) | 
| from abc import ABC, abstractmethod | |||||
| from collections.abc import Generator, Sequence | |||||
| from typing import Any, Optional | |||||
| from core.agent.entities import AgentInvokeMessage | |||||
| from core.agent.plugin_entities import AgentStrategyParameter | |||||
| class BaseAgentStrategy(ABC): | |||||
| """ | |||||
| Agent Strategy | |||||
| """ | |||||
| def invoke( | |||||
| self, | |||||
| params: dict[str, Any], | |||||
| user_id: str, | |||||
| conversation_id: Optional[str] = None, | |||||
| app_id: Optional[str] = None, | |||||
| message_id: Optional[str] = None, | |||||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||||
| """ | |||||
| Invoke the agent strategy. | |||||
| """ | |||||
| yield from self._invoke(params, user_id, conversation_id, app_id, message_id) | |||||
| def get_parameters(self) -> Sequence[AgentStrategyParameter]: | |||||
| """ | |||||
| Get the parameters for the agent strategy. | |||||
| """ | |||||
| return [] | |||||
| @abstractmethod | |||||
| def _invoke( | |||||
| self, | |||||
| params: dict[str, Any], | |||||
| user_id: str, | |||||
| conversation_id: Optional[str] = None, | |||||
| app_id: Optional[str] = None, | |||||
| message_id: Optional[str] = None, | |||||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||||
| pass | 
| from collections.abc import Generator, Sequence | |||||
| from typing import Any, Optional | |||||
| from core.agent.entities import AgentInvokeMessage | |||||
| from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter | |||||
| from core.agent.strategy.base import BaseAgentStrategy | |||||
| from core.plugin.manager.agent import PluginAgentManager | |||||
| from core.plugin.utils.converter import convert_parameters_to_plugin_format | |||||
| class PluginAgentStrategy(BaseAgentStrategy): | |||||
| """ | |||||
| Agent Strategy | |||||
| """ | |||||
| tenant_id: str | |||||
| declaration: AgentStrategyEntity | |||||
| def __init__(self, tenant_id: str, declaration: AgentStrategyEntity): | |||||
| self.tenant_id = tenant_id | |||||
| self.declaration = declaration | |||||
| def get_parameters(self) -> Sequence[AgentStrategyParameter]: | |||||
| return self.declaration.parameters | |||||
| def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]: | |||||
| """ | |||||
| Initialize the parameters for the agent strategy. | |||||
| """ | |||||
| for parameter in self.declaration.parameters: | |||||
| params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name)) | |||||
| return params | |||||
| def _invoke( | |||||
| self, | |||||
| params: dict[str, Any], | |||||
| user_id: str, | |||||
| conversation_id: Optional[str] = None, | |||||
| app_id: Optional[str] = None, | |||||
| message_id: Optional[str] = None, | |||||
| ) -> Generator[AgentInvokeMessage, None, None]: | |||||
| """ | |||||
| Invoke the agent strategy. | |||||
| """ | |||||
| manager = PluginAgentManager() | |||||
| initialized_params = self.initialize_parameters(params) | |||||
| params = convert_parameters_to_plugin_format(initialized_params) | |||||
| yield from manager.invoke( | |||||
| tenant_id=self.tenant_id, | |||||
| user_id=user_id, | |||||
| agent_provider=self.declaration.identity.provider, | |||||
| agent_strategy=self.declaration.identity.name, | |||||
| agent_params=params, | |||||
| conversation_id=conversation_id, | |||||
| app_id=app_id, | |||||
| message_id=message_id, | |||||
| ) | 
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | ||||
| from core.entities.model_entities import ModelStatus | from core.entities.model_entities import ModelStatus | ||||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | ||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.entities.llm_entities import LLMMode | |||||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| stop = completion_params["stop"] | stop = completion_params["stop"] | ||||
| del completion_params["stop"] | del completion_params["stop"] | ||||
| model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) | |||||
| # get model mode | # get model mode | ||||
| model_mode = model_config.mode | model_mode = model_config.mode | ||||
| if not model_mode: | if not model_mode: | ||||
| mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) | |||||
| model_mode = mode_enum.value | |||||
| model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) | |||||
| model_mode = LLMMode.CHAT.value | |||||
| if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): | |||||
| model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value | |||||
| if not model_schema: | if not model_schema: | ||||
| raise ValueError(f"Model {model_name} not exist.") | raise ValueError(f"Model {model_name} not exist.") | 
| from typing import Any | from typing import Any | ||||
| from core.app.app_config.entities import ModelConfigEntity | from core.app.app_config.entities import ModelConfigEntity | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | ||||
| from core.model_runtime.model_providers import model_provider_factory | |||||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| raise ValueError("model must be of object type") | raise ValueError("model must be of object type") | ||||
| # model.provider | # model.provider | ||||
| model_provider_factory = ModelProviderFactory(tenant_id) | |||||
| provider_entities = model_provider_factory.get_providers() | provider_entities = model_provider_factory.get_providers() | ||||
| model_provider_names = [provider.provider for provider in provider_entities] | model_provider_names = [provider.provider for provider in provider_entities] | ||||
| if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: | |||||
| if "provider" not in config["model"]: | |||||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | |||||
| if "/" not in config["model"]["provider"]: | |||||
| config["model"]["provider"] = ( | |||||
| f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" | |||||
| ) | |||||
| if config["model"]["provider"] not in model_provider_names: | |||||
| raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") | ||||
| # model.name | # model.name | 
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[True], | |||||
| ) -> Generator[str, None, None]: ... | |||||
| streaming: Literal[False], | |||||
| ) -> Mapping[str, Any]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | |||||
| args: Mapping, | |||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[False], | |||||
| ) -> Mapping[str, Any]: ... | |||||
| streaming: Literal[True], | |||||
| ) -> Generator[Mapping | str, None, None]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | |||||
| args: Mapping, | |||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | |||||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... | |||||
| streaming: bool, | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | |||||
| args: Mapping, | |||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| ): | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| workflow_run_id=workflow_run_id, | workflow_run_id=workflow_run_id, | ||||
| ) | ) | ||||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | ||||
| contexts.plugin_tool_providers.set({}) | |||||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||||
| return self._generate( | return self._generate( | ||||
| workflow=workflow, | workflow=workflow, | ||||
| ) | ) | ||||
| def single_iteration_generate( | def single_iteration_generate( | ||||
| self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True | |||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||||
| self, | |||||
| app_model: App, | |||||
| workflow: Workflow, | |||||
| node_id: str, | |||||
| user: Account | EndUser, | |||||
| args: Mapping, | |||||
| streaming: bool = True, | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| ), | ), | ||||
| ) | ) | ||||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | ||||
| contexts.plugin_tool_providers.set({}) | |||||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||||
| return self._generate( | return self._generate( | ||||
| workflow=workflow, | workflow=workflow, | ||||
| application_generate_entity: AdvancedChatAppGenerateEntity, | application_generate_entity: AdvancedChatAppGenerateEntity, | ||||
| conversation: Optional[Conversation] = None, | conversation: Optional[Conversation] = None, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| class AppGeneratorTTSPublisher: | class AppGeneratorTTSPublisher: | ||||
| def __init__(self, tenant_id: str, voice: str): | |||||
| def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None): | |||||
| self.logger = logging.getLogger(__name__) | self.logger = logging.getLogger(__name__) | ||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.msg_text = "" | self.msg_text = "" | ||||
| self.model_instance = self.model_manager.get_default_model_instance( | self.model_instance = self.model_manager.get_default_model_instance( | ||||
| tenant_id=self.tenant_id, model_type=ModelType.TTS | tenant_id=self.tenant_id, model_type=ModelType.TTS | ||||
| ) | ) | ||||
| self.voices = self.model_instance.get_tts_voices() | |||||
| self.voices = self.model_instance.get_tts_voices(language=language) | |||||
| values = [voice.get("value") for voice in self.voices] | values = [voice.get("value") for voice in self.voices] | ||||
| self.voice = voice | self.voice = voice | ||||
| if not voice or voice not in values: | if not voice or voice not in values: | 
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | ||||
| workflow=workflow, | workflow=workflow, | ||||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | node_id=self.application_generate_entity.single_iteration_run.node_id, | ||||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||||
| user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), | |||||
| ) | ) | ||||
| else: | else: | ||||
| inputs = self.application_generate_entity.inputs | inputs = self.application_generate_entity.inputs | 
| import json | |||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import Any, cast | from typing import Any, cast | ||||
| @classmethod | @classmethod | ||||
| def convert_stream_full_response( | def convert_stream_full_response( | ||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | cls, stream_response: Generator[AppStreamResponse, None, None] | ||||
| ) -> Generator[str, Any, None]: | |||||
| ) -> Generator[dict | str, Any, None]: | |||||
| """ | """ | ||||
| Convert stream full response. | Convert stream full response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(data) | response_chunk.update(data) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | |||||
| @classmethod | @classmethod | ||||
| def convert_stream_simple_response( | def convert_stream_simple_response( | ||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | cls, stream_response: Generator[AppStreamResponse, None, None] | ||||
| ) -> Generator[str, Any, None]: | |||||
| ) -> Generator[dict | str, Any, None]: | |||||
| """ | """ | ||||
| Convert stream simple response. | Convert stream simple response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | 
| ) | ) | ||||
| from core.app.entities.queue_entities import ( | from core.app.entities.queue_entities import ( | ||||
| QueueAdvancedChatMessageEndEvent, | QueueAdvancedChatMessageEndEvent, | ||||
| QueueAgentLogEvent, | |||||
| QueueAnnotationReplyEvent, | QueueAnnotationReplyEvent, | ||||
| QueueErrorEvent, | QueueErrorEvent, | ||||
| QueueIterationCompletedEvent, | QueueIterationCompletedEvent, | ||||
| and features_dict["text_to_speech"].get("enabled") | and features_dict["text_to_speech"].get("enabled") | ||||
| and features_dict["text_to_speech"].get("autoPlay") == "enabled" | and features_dict["text_to_speech"].get("autoPlay") == "enabled" | ||||
| ): | ): | ||||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) | |||||
| tts_publisher = AppGeneratorTTSPublisher( | |||||
| tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") | |||||
| ) | |||||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | ||||
| while True: | while True: | ||||
| else: | else: | ||||
| start_listener_time = time.time() | start_listener_time = time.time() | ||||
| yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) | yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logger.exception(f"Failed to listen audio message, task_id: {task_id}") | logger.exception(f"Failed to listen audio message, task_id: {task_id}") | ||||
| break | break | ||||
| if tts_publisher: | if tts_publisher: | ||||
| session.commit() | session.commit() | ||||
| yield self._message_end_to_stream_response() | yield self._message_end_to_stream_response() | ||||
| elif isinstance(event, QueueAgentLogEvent): | |||||
| yield self._workflow_cycle_manager._handle_agent_log( | |||||
| task_id=self._application_generate_entity.task_id, event=event | |||||
| ) | |||||
| else: | else: | ||||
| continue | continue | ||||
| import contextvars | |||||
| import logging | import logging | ||||
| import threading | import threading | ||||
| import uuid | import uuid | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[True], | |||||
| ) -> Generator[str, None, None]: ... | |||||
| streaming: Literal[False], | |||||
| ) -> Mapping[str, Any]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[False], | |||||
| ) -> Mapping[str, Any]: ... | |||||
| streaming: Literal[True], | |||||
| ) -> Generator[Mapping | str, None, None]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool, | streaming: bool, | ||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: ... | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| ): | |||||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| target=self._generate_worker, | target=self._generate_worker, | ||||
| kwargs={ | kwargs={ | ||||
| "flask_app": current_app._get_current_object(), # type: ignore | "flask_app": current_app._get_current_object(), # type: ignore | ||||
| "context": contextvars.copy_context(), | |||||
| "application_generate_entity": application_generate_entity, | "application_generate_entity": application_generate_entity, | ||||
| "queue_manager": queue_manager, | "queue_manager": queue_manager, | ||||
| "conversation_id": conversation.id, | "conversation_id": conversation.id, | ||||
| def _generate_worker( | def _generate_worker( | ||||
| self, | self, | ||||
| flask_app: Flask, | flask_app: Flask, | ||||
| context: contextvars.Context, | |||||
| application_generate_entity: AgentChatAppGenerateEntity, | application_generate_entity: AgentChatAppGenerateEntity, | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation_id: str, | conversation_id: str, | ||||
| :param message_id: message ID | :param message_id: message ID | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| try: | try: | ||||
| # get conversation and message | # get conversation and message | 
| from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig | from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig | ||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.apps.base_app_runner import AppRunner | from core.app.apps.base_app_runner import AppRunner | ||||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity | |||||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity | |||||
| from core.app.entities.queue_entities import QueueAnnotationReplyEvent | from core.app.entities.queue_entities import QueueAnnotationReplyEvent | ||||
| from core.memory.token_buffer_memory import TokenBufferMemory | from core.memory.token_buffer_memory import TokenBufferMemory | ||||
| from core.model_manager import ModelInstance | from core.model_manager import ModelInstance | ||||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage | |||||
| from core.model_runtime.entities.llm_entities import LLMMode | |||||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey | from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.moderation.base import ModerationError | from core.moderation.base import ModerationError | ||||
| from core.tools.entities.tool_entities import ToolRuntimeVariablePool | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import App, Conversation, Message, MessageAgentThought | |||||
| from models.tools import ToolConversationVariables | |||||
| from models.model import App, Conversation, Message | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| app_record=app_record, | app_record=app_record, | ||||
| model_config=application_generate_entity.model_conf, | model_config=application_generate_entity.model_conf, | ||||
| prompt_template_entity=app_config.prompt_template, | prompt_template_entity=app_config.prompt_template, | ||||
| inputs=inputs, | |||||
| files=files, | |||||
| inputs=dict(inputs), | |||||
| files=list(files), | |||||
| query=query, | query=query, | ||||
| ) | ) | ||||
| app_record=app_record, | app_record=app_record, | ||||
| model_config=application_generate_entity.model_conf, | model_config=application_generate_entity.model_conf, | ||||
| prompt_template_entity=app_config.prompt_template, | prompt_template_entity=app_config.prompt_template, | ||||
| inputs=inputs, | |||||
| files=files, | |||||
| inputs=dict(inputs), | |||||
| files=list(files), | |||||
| query=query, | query=query, | ||||
| memory=memory, | memory=memory, | ||||
| ) | ) | ||||
| app_id=app_record.id, | app_id=app_record.id, | ||||
| tenant_id=app_config.tenant_id, | tenant_id=app_config.tenant_id, | ||||
| app_generate_entity=application_generate_entity, | app_generate_entity=application_generate_entity, | ||||
| inputs=inputs, | |||||
| query=query, | |||||
| inputs=dict(inputs), | |||||
| query=query or "", | |||||
| message_id=message.id, | message_id=message.id, | ||||
| ) | ) | ||||
| except ModerationError as e: | except ModerationError as e: | ||||
| app_record=app_record, | app_record=app_record, | ||||
| model_config=application_generate_entity.model_conf, | model_config=application_generate_entity.model_conf, | ||||
| prompt_template_entity=app_config.prompt_template, | prompt_template_entity=app_config.prompt_template, | ||||
| inputs=inputs, | |||||
| files=files, | |||||
| query=query, | |||||
| inputs=dict(inputs), | |||||
| files=list(files), | |||||
| query=query or "", | |||||
| memory=memory, | memory=memory, | ||||
| ) | ) | ||||
| return | return | ||||
| agent_entity = app_config.agent | agent_entity = app_config.agent | ||||
| if not agent_entity: | |||||
| raise ValueError("Agent entity not found") | |||||
| # load tool variables | |||||
| tool_conversation_variables = self._load_tool_variables( | |||||
| conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id | |||||
| ) | |||||
| # convert db variables to tool variables | |||||
| tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) | |||||
| assert agent_entity is not None | |||||
| # init model instance | # init model instance | ||||
| model_instance = ModelInstance( | model_instance = ModelInstance( | ||||
| app_record=app_record, | app_record=app_record, | ||||
| model_config=application_generate_entity.model_conf, | model_config=application_generate_entity.model_conf, | ||||
| prompt_template_entity=app_config.prompt_template, | prompt_template_entity=app_config.prompt_template, | ||||
| inputs=inputs, | |||||
| files=files, | |||||
| query=query, | |||||
| inputs=dict(inputs), | |||||
| files=list(files), | |||||
| query=query or "", | |||||
| memory=memory, | memory=memory, | ||||
| ) | ) | ||||
| user_id=application_generate_entity.user_id, | user_id=application_generate_entity.user_id, | ||||
| memory=memory, | memory=memory, | ||||
| prompt_messages=prompt_message, | prompt_messages=prompt_message, | ||||
| variables_pool=tool_variables, | |||||
| db_variables=tool_conversation_variables, | |||||
| model_instance=model_instance, | model_instance=model_instance, | ||||
| ) | ) | ||||
| stream=application_generate_entity.stream, | stream=application_generate_entity.stream, | ||||
| agent=True, | agent=True, | ||||
| ) | ) | ||||
| def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: | |||||
| """ | |||||
| load tool variables from database | |||||
| """ | |||||
| tool_variables: ToolConversationVariables | None = ( | |||||
| db.session.query(ToolConversationVariables) | |||||
| .filter( | |||||
| ToolConversationVariables.conversation_id == conversation_id, | |||||
| ToolConversationVariables.tenant_id == tenant_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if tool_variables: | |||||
| # save tool variables to session, so that we can update it later | |||||
| db.session.add(tool_variables) | |||||
| else: | |||||
| # create new tool variables | |||||
| tool_variables = ToolConversationVariables( | |||||
| conversation_id=conversation_id, | |||||
| user_id=user_id, | |||||
| tenant_id=tenant_id, | |||||
| variables_str="[]", | |||||
| ) | |||||
| db.session.add(tool_variables) | |||||
| db.session.commit() | |||||
| return tool_variables | |||||
| def _convert_db_variables_to_tool_variables( | |||||
| self, db_variables: ToolConversationVariables | |||||
| ) -> ToolRuntimeVariablePool: | |||||
| """ | |||||
| convert db variables to tool variables | |||||
| """ | |||||
| return ToolRuntimeVariablePool( | |||||
| **{ | |||||
| "conversation_id": db_variables.conversation_id, | |||||
| "user_id": db_variables.user_id, | |||||
| "tenant_id": db_variables.tenant_id, | |||||
| "pool": db_variables.variables, | |||||
| } | |||||
| ) | |||||
| def _get_usage_of_all_agent_thoughts( | |||||
| self, model_config: ModelConfigWithCredentialsEntity, message: Message | |||||
| ) -> LLMUsage: | |||||
| """ | |||||
| Get usage of all agent thoughts | |||||
| :param model_config: model config | |||||
| :param message: message | |||||
| :return: | |||||
| """ | |||||
| agent_thoughts = ( | |||||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() | |||||
| ) | |||||
| all_message_tokens = 0 | |||||
| all_answer_tokens = 0 | |||||
| for agent_thought in agent_thoughts: | |||||
| all_message_tokens += agent_thought.message_tokens | |||||
| all_answer_tokens += agent_thought.answer_tokens | |||||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||||
| return model_type_instance._calc_response_usage( | |||||
| model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens | |||||
| ) | 
| import json | |||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import cast | from typing import cast | ||||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AppStreamResponse, | |||||
| ChatbotAppBlockingResponse, | ChatbotAppBlockingResponse, | ||||
| ChatbotAppStreamResponse, | ChatbotAppStreamResponse, | ||||
| ErrorStreamResponse, | ErrorStreamResponse, | ||||
| return response | return response | ||||
| @classmethod | @classmethod | ||||
| def convert_stream_full_response( # type: ignore[override] | |||||
| cls, | |||||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], | |||||
| ) -> Generator[str, None, None]: | |||||
| def convert_stream_full_response( | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream full response. | Convert stream full response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(data) | response_chunk.update(data) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | |||||
| @classmethod | @classmethod | ||||
| def convert_stream_simple_response( # type: ignore[override] | |||||
| cls, | |||||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], | |||||
| ) -> Generator[str, None, None]: | |||||
| def convert_stream_simple_response( | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream simple response. | Convert stream simple response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | 
| @classmethod | @classmethod | ||||
| def convert( | def convert( | ||||
| cls, | |||||
| response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], | |||||
| invoke_from: InvokeFrom, | |||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||||
| cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: | |||||
| if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: | if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: | ||||
| if isinstance(response, AppBlockingResponse): | if isinstance(response, AppBlockingResponse): | ||||
| return cls.convert_blocking_full_response(response) | return cls.convert_blocking_full_response(response) | ||||
| else: | else: | ||||
| def _generate_full_response() -> Generator[str, Any, None]: | |||||
| for chunk in cls.convert_stream_full_response(response): | |||||
| if chunk == "ping": | |||||
| yield f"event: {chunk}\n\n" | |||||
| else: | |||||
| yield f"data: {chunk}\n\n" | |||||
| def _generate_full_response() -> Generator[dict | str, Any, None]: | |||||
| yield from cls.convert_stream_full_response(response) | |||||
| return _generate_full_response() | return _generate_full_response() | ||||
| else: | else: | ||||
| return cls.convert_blocking_simple_response(response) | return cls.convert_blocking_simple_response(response) | ||||
| else: | else: | ||||
| def _generate_simple_response() -> Generator[str, Any, None]: | |||||
| for chunk in cls.convert_stream_simple_response(response): | |||||
| if chunk == "ping": | |||||
| yield f"event: {chunk}\n\n" | |||||
| else: | |||||
| yield f"data: {chunk}\n\n" | |||||
| def _generate_simple_response() -> Generator[dict | str, Any, None]: | |||||
| yield from cls.convert_stream_simple_response(response) | |||||
| return _generate_simple_response() | return _generate_simple_response() | ||||
| @abstractmethod | @abstractmethod | ||||
| def convert_stream_full_response( | def convert_stream_full_response( | ||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | cls, stream_response: Generator[AppStreamResponse, None, None] | ||||
| ) -> Generator[str, None, None]: | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @classmethod | @classmethod | ||||
| @abstractmethod | @abstractmethod | ||||
| def convert_stream_simple_response( | def convert_stream_simple_response( | ||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | cls, stream_response: Generator[AppStreamResponse, None, None] | ||||
| ) -> Generator[str, None, None]: | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @classmethod | @classmethod | 
| from collections.abc import Mapping, Sequence | |||||
| from typing import TYPE_CHECKING, Any, Optional | |||||
| import json | |||||
| from collections.abc import Generator, Mapping, Sequence | |||||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||||
| from core.app.app_config.entities import VariableEntityType | from core.app.app_config.entities import VariableEntityType | ||||
| from core.file import File, FileUploadConfig | from core.file import File, FileUploadConfig | ||||
| if isinstance(value, str): | if isinstance(value, str): | ||||
| return value.replace("\x00", "") | return value.replace("\x00", "") | ||||
| return value | return value | ||||
| @classmethod | |||||
| def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]): | |||||
| """ | |||||
| Convert messages into event stream | |||||
| """ | |||||
| if isinstance(generator, dict): | |||||
| return generator | |||||
| else: | |||||
| def gen(): | |||||
| for message in generator: | |||||
| if isinstance(message, (Mapping, dict)): | |||||
| yield f"data: {json.dumps(message)}\n\n" | |||||
| else: | |||||
| yield f"event: {message}\n\n" | |||||
| return gen() | 
| import time | import time | ||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| from enum import Enum | from enum import Enum | ||||
| from typing import Any | |||||
| from typing import Any, Optional | |||||
| from sqlalchemy.orm import DeclarativeMeta | from sqlalchemy.orm import DeclarativeMeta | ||||
| Set task stop flag | Set task stop flag | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) | |||||
| result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id)) | |||||
| if result is None: | if result is None: | ||||
| return | return | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[True], | streaming: Literal[True], | ||||
| ) -> Generator[str, None, None]: ... | |||||
| ) -> Generator[Mapping | str, None, None]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool, | streaming: bool, | ||||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... | |||||
| ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| ): | |||||
| ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| import json | |||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import cast | from typing import cast | ||||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AppStreamResponse, | |||||
| ChatbotAppBlockingResponse, | ChatbotAppBlockingResponse, | ||||
| ChatbotAppStreamResponse, | ChatbotAppStreamResponse, | ||||
| ErrorStreamResponse, | ErrorStreamResponse, | ||||
| @classmethod | @classmethod | ||||
| def convert_stream_full_response( | def convert_stream_full_response( | ||||
| cls, | |||||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream full response. | Convert stream full response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(data) | response_chunk.update(data) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | |||||
| @classmethod | @classmethod | ||||
| def convert_stream_simple_response( | def convert_stream_simple_response( | ||||
| cls, | |||||
| stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream simple response. | Convert stream simple response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | 
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[True], | streaming: Literal[True], | ||||
| ) -> Generator[str, None, None]: ... | |||||
| ) -> Generator[str | Mapping[str, Any], None, None]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool, | |||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||||
| streaming: bool = False, | |||||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ... | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| ): | |||||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| stream: bool = True, | stream: bool = True, | ||||
| ) -> Union[Mapping[str, Any], Generator[str, None, None]]: | |||||
| ) -> Union[Mapping, Generator[Mapping | str, None, None]]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| import json | |||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import cast | from typing import cast | ||||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AppStreamResponse, | |||||
| CompletionAppBlockingResponse, | CompletionAppBlockingResponse, | ||||
| CompletionAppStreamResponse, | CompletionAppStreamResponse, | ||||
| ErrorStreamResponse, | ErrorStreamResponse, | ||||
| @classmethod | @classmethod | ||||
| def convert_stream_full_response( | def convert_stream_full_response( | ||||
| cls, | |||||
| stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream full response. | Convert stream full response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(data) | response_chunk.update(data) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | |||||
| @classmethod | @classmethod | ||||
| def convert_stream_simple_response( | def convert_stream_simple_response( | ||||
| cls, | |||||
| stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream simple response. | Convert stream simple response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | 
| *, | *, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Account | EndUser, | |||||
| user: Union[Account, EndUser], | |||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[True], | streaming: Literal[True], | ||||
| call_depth: int = 0, | |||||
| workflow_thread_pool_id: Optional[str] = None, | |||||
| ) -> Generator[str, None, None]: ... | |||||
| call_depth: int, | |||||
| workflow_thread_pool_id: Optional[str], | |||||
| ) -> Generator[Mapping | str, None, None]: ... | |||||
| @overload | @overload | ||||
| def generate( | def generate( | ||||
| *, | *, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Account | EndUser, | |||||
| user: Union[Account, EndUser], | |||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: Literal[False], | streaming: Literal[False], | ||||
| call_depth: int = 0, | |||||
| workflow_thread_pool_id: Optional[str] = None, | |||||
| call_depth: int, | |||||
| workflow_thread_pool_id: Optional[str], | |||||
| ) -> Mapping[str, Any]: ... | ) -> Mapping[str, Any]: ... | ||||
| @overload | @overload | ||||
| *, | *, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Account | EndUser, | |||||
| user: Union[Account, EndUser], | |||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | |||||
| call_depth: int = 0, | |||||
| workflow_thread_pool_id: Optional[str] = None, | |||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: ... | |||||
| streaming: bool, | |||||
| call_depth: int, | |||||
| workflow_thread_pool_id: Optional[str], | |||||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| *, | *, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| user: Account | EndUser, | |||||
| user: Union[Account, EndUser], | |||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| call_depth: int = 0, | call_depth: int = 0, | ||||
| workflow_thread_pool_id: Optional[str] = None, | workflow_thread_pool_id: Optional[str] = None, | ||||
| ): | |||||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: | |||||
| files: Sequence[Mapping[str, Any]] = args.get("files") or [] | files: Sequence[Mapping[str, Any]] = args.get("files") or [] | ||||
| # parse files | # parse files | ||||
| trace_manager=trace_manager, | trace_manager=trace_manager, | ||||
| workflow_run_id=workflow_run_id, | workflow_run_id=workflow_run_id, | ||||
| ) | ) | ||||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | ||||
| contexts.plugin_tool_providers.set({}) | |||||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||||
| return self._generate( | return self._generate( | ||||
| app_model=app_model, | app_model=app_model, | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| workflow_thread_pool_id: Optional[str] = None, | workflow_thread_pool_id: Optional[str] = None, | ||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||||
| """ | |||||
| Generate App response. | |||||
| :param app_model: App | |||||
| :param workflow: Workflow | |||||
| :param user: account or end user | |||||
| :param application_generate_entity: application generate entity | |||||
| :param invoke_from: invoke from source | |||||
| :param stream: is stream | |||||
| :param workflow_thread_pool_id: workflow thread pool id | |||||
| """ | |||||
| # init queue manager | # init queue manager | ||||
| queue_manager = WorkflowAppQueueManager( | queue_manager = WorkflowAppQueueManager( | ||||
| task_id=application_generate_entity.task_id, | task_id=application_generate_entity.task_id, | ||||
| app_model: App, | app_model: App, | ||||
| workflow: Workflow, | workflow: Workflow, | ||||
| node_id: str, | node_id: str, | ||||
| user: Account, | |||||
| user: Account | EndUser, | |||||
| args: Mapping[str, Any], | args: Mapping[str, Any], | ||||
| streaming: bool = True, | streaming: bool = True, | ||||
| ) -> Mapping[str, Any] | Generator[str, None, None]: | |||||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||||
| """ | """ | ||||
| Generate App response. | Generate App response. | ||||
| workflow_run_id=str(uuid.uuid4()), | workflow_run_id=str(uuid.uuid4()), | ||||
| ) | ) | ||||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | ||||
| contexts.plugin_tool_providers.set({}) | |||||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||||
| return self._generate( | return self._generate( | ||||
| app_model=app_model, | app_model=app_model, | 
| import json | |||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import cast | from typing import cast | ||||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AppStreamResponse, | |||||
| ErrorStreamResponse, | ErrorStreamResponse, | ||||
| NodeFinishStreamResponse, | NodeFinishStreamResponse, | ||||
| NodeStartStreamResponse, | NodeStartStreamResponse, | ||||
| @classmethod | @classmethod | ||||
| def convert_stream_full_response( | def convert_stream_full_response( | ||||
| cls, | |||||
| stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream full response. | Convert stream full response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(data) | response_chunk.update(data) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | |||||
| @classmethod | @classmethod | ||||
| def convert_stream_simple_response( | def convert_stream_simple_response( | ||||
| cls, | |||||
| stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] | |||||
| ) -> Generator[str, None, None]: | |||||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||||
| ) -> Generator[dict | str, None, None]: | |||||
| """ | """ | ||||
| Convert stream simple response. | Convert stream simple response. | ||||
| :param stream_response: stream response | :param stream_response: stream response | ||||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | ||||
| else: | else: | ||||
| response_chunk.update(sub_stream_response.to_dict()) | response_chunk.update(sub_stream_response.to_dict()) | ||||
| yield json.dumps(response_chunk) | |||||
| yield response_chunk | 
| WorkflowAppGenerateEntity, | WorkflowAppGenerateEntity, | ||||
| ) | ) | ||||
| from core.app.entities.queue_entities import ( | from core.app.entities.queue_entities import ( | ||||
| QueueAgentLogEvent, | |||||
| QueueErrorEvent, | QueueErrorEvent, | ||||
| QueueIterationCompletedEvent, | QueueIterationCompletedEvent, | ||||
| QueueIterationNextEvent, | QueueIterationNextEvent, | ||||
| and features_dict["text_to_speech"].get("enabled") | and features_dict["text_to_speech"].get("enabled") | ||||
| and features_dict["text_to_speech"].get("autoPlay") == "enabled" | and features_dict["text_to_speech"].get("autoPlay") == "enabled" | ||||
| ): | ): | ||||
| tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) | |||||
| tts_publisher = AppGeneratorTTSPublisher( | |||||
| tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language") | |||||
| ) | |||||
| for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): | ||||
| while True: | while True: | ||||
| yield self._text_chunk_to_stream_response( | yield self._text_chunk_to_stream_response( | ||||
| delta_text, from_variable_selector=event.from_variable_selector | delta_text, from_variable_selector=event.from_variable_selector | ||||
| ) | ) | ||||
| elif isinstance(event, QueueAgentLogEvent): | |||||
| yield self._workflow_cycle_manager._handle_agent_log( | |||||
| task_id=self._application_generate_entity.task_id, event=event | |||||
| ) | |||||
| else: | else: | ||||
| continue | continue | ||||
| from core.app.apps.base_app_runner import AppRunner | from core.app.apps.base_app_runner import AppRunner | ||||
| from core.app.entities.queue_entities import ( | from core.app.entities.queue_entities import ( | ||||
| AppQueueEvent, | AppQueueEvent, | ||||
| QueueAgentLogEvent, | |||||
| QueueIterationCompletedEvent, | QueueIterationCompletedEvent, | ||||
| QueueIterationNextEvent, | QueueIterationNextEvent, | ||||
| QueueIterationStartEvent, | QueueIterationStartEvent, | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | from core.workflow.entities.node_entities import NodeRunMetadataKey | ||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| from core.workflow.graph_engine.entities.event import ( | from core.workflow.graph_engine.entities.event import ( | ||||
| AgentLogEvent, | |||||
| GraphEngineEvent, | GraphEngineEvent, | ||||
| GraphRunFailedEvent, | GraphRunFailedEvent, | ||||
| GraphRunPartialSucceededEvent, | GraphRunPartialSucceededEvent, | ||||
| predecessor_node_id=event.predecessor_node_id, | predecessor_node_id=event.predecessor_node_id, | ||||
| in_iteration_id=event.in_iteration_id, | in_iteration_id=event.in_iteration_id, | ||||
| parallel_mode_run_id=event.parallel_mode_run_id, | parallel_mode_run_id=event.parallel_mode_run_id, | ||||
| agent_strategy=event.agent_strategy, | |||||
| ) | ) | ||||
| ) | ) | ||||
| elif isinstance(event, NodeRunSucceededEvent): | elif isinstance(event, NodeRunSucceededEvent): | ||||
| retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id | retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id | ||||
| ) | ) | ||||
| ) | ) | ||||
| elif isinstance(event, AgentLogEvent): | |||||
| self._publish_event( | |||||
| QueueAgentLogEvent( | |||||
| id=event.id, | |||||
| label=event.label, | |||||
| node_execution_id=event.node_execution_id, | |||||
| parent_id=event.parent_id, | |||||
| error=event.error, | |||||
| status=event.status, | |||||
| data=event.data, | |||||
| metadata=event.metadata, | |||||
| ) | |||||
| ) | |||||
| elif isinstance(event, ParallelBranchRunStartedEvent): | elif isinstance(event, ParallelBranchRunStartedEvent): | ||||
| self._publish_event( | self._publish_event( | ||||
| QueueParallelBranchRunStartedEvent( | QueueParallelBranchRunStartedEvent( | 
| """ | """ | ||||
| node_id: str | node_id: str | ||||
| inputs: dict | |||||
| inputs: Mapping | |||||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | single_iteration_run: Optional[SingleIterationRunEntity] = None | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey | |||||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| from core.workflow.nodes.base import BaseNodeData | from core.workflow.nodes.base import BaseNodeData | ||||
| PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" | PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" | ||||
| PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" | PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" | ||||
| PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" | PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" | ||||
| AGENT_LOG = "agent_log" | |||||
| ERROR = "error" | ERROR = "error" | ||||
| PING = "ping" | PING = "ping" | ||||
| STOP = "stop" | STOP = "stop" | ||||
| start_at: datetime | start_at: datetime | ||||
| parallel_mode_run_id: Optional[str] = None | parallel_mode_run_id: Optional[str] = None | ||||
| """iteratoin run in parallel mode run id""" | """iteratoin run in parallel mode run id""" | ||||
| agent_strategy: Optional[AgentNodeStrategyInit] = None | |||||
| class QueueNodeSucceededEvent(AppQueueEvent): | class QueueNodeSucceededEvent(AppQueueEvent): | ||||
| iteration_duration_map: Optional[dict[str, float]] = None | iteration_duration_map: Optional[dict[str, float]] = None | ||||
| class QueueAgentLogEvent(AppQueueEvent): | |||||
| """ | |||||
| QueueAgentLogEvent entity | |||||
| """ | |||||
| event: QueueEvent = QueueEvent.AGENT_LOG | |||||
| id: str | |||||
| label: str | |||||
| node_execution_id: str | |||||
| parent_id: str | None | |||||
| error: str | None | |||||
| status: str | |||||
| data: Mapping[str, Any] | |||||
| metadata: Optional[Mapping[str, Any]] = None | |||||
| class QueueNodeRetryEvent(QueueNodeStartedEvent): | class QueueNodeRetryEvent(QueueNodeStartedEvent): | ||||
| """QueueNodeRetryEvent entity""" | """QueueNodeRetryEvent entity""" | ||||
| from core.model_runtime.entities.llm_entities import LLMResult | from core.model_runtime.entities.llm_entities import LLMResult | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||||
| from models.workflow import WorkflowNodeExecutionStatus | from models.workflow import WorkflowNodeExecutionStatus | ||||
| ITERATION_COMPLETED = "iteration_completed" | ITERATION_COMPLETED = "iteration_completed" | ||||
| TEXT_CHUNK = "text_chunk" | TEXT_CHUNK = "text_chunk" | ||||
| TEXT_REPLACE = "text_replace" | TEXT_REPLACE = "text_replace" | ||||
| AGENT_LOG = "agent_log" | |||||
| class StreamResponse(BaseModel): | class StreamResponse(BaseModel): | ||||
| parent_parallel_start_node_id: Optional[str] = None | parent_parallel_start_node_id: Optional[str] = None | ||||
| iteration_id: Optional[str] = None | iteration_id: Optional[str] = None | ||||
| parallel_run_id: Optional[str] = None | parallel_run_id: Optional[str] = None | ||||
| agent_strategy: Optional[AgentNodeStrategyInit] = None | |||||
| event: StreamEvent = StreamEvent.NODE_STARTED | event: StreamEvent = StreamEvent.NODE_STARTED | ||||
| workflow_run_id: str | workflow_run_id: str | ||||
| workflow_run_id: str | workflow_run_id: str | ||||
| data: Data | data: Data | ||||
| class AgentLogStreamResponse(StreamResponse): | |||||
| """ | |||||
| AgentLogStreamResponse entity | |||||
| """ | |||||
| class Data(BaseModel): | |||||
| """ | |||||
| Data entity | |||||
| """ | |||||
| node_execution_id: str | |||||
| id: str | |||||
| label: str | |||||
| parent_id: str | None | |||||
| error: str | None | |||||
| status: str | |||||
| data: Mapping[str, Any] | |||||
| metadata: Optional[Mapping[str, Any]] = None | |||||
| event: StreamEvent = StreamEvent.AGENT_LOG | |||||
| data: Data | 
| if isinstance(prompt_message.content, str): | if isinstance(prompt_message.content, str): | ||||
| text += prompt_message.content + "\n" | text += prompt_message.content + "\n" | ||||
| moderation_result = moderation.check_moderation(model_config, text) | |||||
| moderation_result = moderation.check_moderation( | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text | |||||
| ) | |||||
| return moderation_result | return moderation_result | 
| and text_to_speech_dict.get("autoPlay") == "enabled" | and text_to_speech_dict.get("autoPlay") == "enabled" | ||||
| and text_to_speech_dict.get("enabled") | and text_to_speech_dict.get("enabled") | ||||
| ): | ): | ||||
| publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) | |||||
| publisher = AppGeneratorTTSPublisher( | |||||
| tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) | |||||
| ) | |||||
| for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): | ||||
| while True: | while True: | ||||
| audio_response = self._listen_audio_msg(publisher, task_id) | audio_response = self._listen_audio_msg(publisher, task_id) | 
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | ||||
| from core.app.entities.queue_entities import ( | from core.app.entities.queue_entities import ( | ||||
| QueueAgentLogEvent, | |||||
| QueueIterationCompletedEvent, | QueueIterationCompletedEvent, | ||||
| QueueIterationNextEvent, | QueueIterationNextEvent, | ||||
| QueueIterationStartEvent, | QueueIterationStartEvent, | ||||
| QueueParallelBranchRunSucceededEvent, | QueueParallelBranchRunSucceededEvent, | ||||
| ) | ) | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AgentLogStreamResponse, | |||||
| IterationNodeCompletedStreamResponse, | IterationNodeCompletedStreamResponse, | ||||
| IterationNodeNextStreamResponse, | IterationNodeNextStreamResponse, | ||||
| IterationNodeStartStreamResponse, | IterationNodeStartStreamResponse, | ||||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | inputs = WorkflowEntry.handle_special_values(event.inputs) | ||||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | process_data = WorkflowEntry.handle_special_values(event.process_data) | ||||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | outputs = WorkflowEntry.handle_special_values(event.outputs) | ||||
| execution_metadata = ( | |||||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||||
| ) | |||||
| execution_metadata_dict = dict(event.execution_metadata or {}) | |||||
| execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None | |||||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | finished_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| elapsed_time = (finished_at - event.start_at).total_seconds() | elapsed_time = (finished_at - event.start_at).total_seconds() | ||||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | parent_parallel_start_node_id=event.parent_parallel_start_node_id, | ||||
| iteration_id=event.in_iteration_id, | iteration_id=event.in_iteration_id, | ||||
| parallel_run_id=event.parallel_mode_run_id, | parallel_run_id=event.parallel_mode_run_id, | ||||
| agent_strategy=event.agent_strategy, | |||||
| ), | ), | ||||
| ) | ) | ||||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | raise ValueError(f"Workflow node execution not found: {node_execution_id}") | ||||
| cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] | cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] | ||||
| return session.merge(cached_workflow_node_execution) | return session.merge(cached_workflow_node_execution) | ||||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||||
| """ | |||||
| Handle agent log | |||||
| :param task_id: task id | |||||
| :param event: agent log event | |||||
| :return: | |||||
| """ | |||||
| return AgentLogStreamResponse( | |||||
| task_id=task_id, | |||||
| data=AgentLogStreamResponse.Data( | |||||
| node_execution_id=event.node_execution_id, | |||||
| id=event.id, | |||||
| parent_id=event.parent_id, | |||||
| label=event.label, | |||||
| error=event.error, | |||||
| status=event.status, | |||||
| data=event.data, | |||||
| metadata=event.metadata, | |||||
| ), | |||||
| ) | 
| from collections.abc import Mapping, Sequence | |||||
| from collections.abc import Iterable, Mapping | |||||
| from typing import Any, Optional, TextIO, Union | from typing import Any, Optional, TextIO, Union | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| self, | self, | ||||
| tool_name: str, | tool_name: str, | ||||
| tool_inputs: Mapping[str, Any], | tool_inputs: Mapping[str, Any], | ||||
| tool_outputs: Sequence[ToolInvokeMessage] | str, | |||||
| tool_outputs: Iterable[ToolInvokeMessage] | str, | |||||
| message_id: Optional[str] = None, | message_id: Optional[str] = None, | ||||
| timer: Optional[Any] = None, | timer: Optional[Any] = None, | ||||
| trace_manager: Optional[TraceQueueManager] = None, | trace_manager: Optional[TraceQueueManager] = None, | 
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler | |||||
| from collections.abc import Generator, Iterable, Mapping | |||||
| from typing import Any, Optional | |||||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text | |||||
| from core.ops.ops_trace_manager import TraceQueueManager | |||||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||||
| class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): | class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): | ||||
| """Callback Handler that prints to std out.""" | """Callback Handler that prints to std out.""" | ||||
| def on_tool_execution( | |||||
| self, | |||||
| tool_name: str, | |||||
| tool_inputs: Mapping[str, Any], | |||||
| tool_outputs: Iterable[ToolInvokeMessage], | |||||
| message_id: Optional[str] = None, | |||||
| timer: Optional[Any] = None, | |||||
| trace_manager: Optional[TraceQueueManager] = None, | |||||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||||
| for tool_output in tool_outputs: | |||||
| print_text("\n[on_tool_execution]\n", color=self.color) | |||||
| print_text("Tool: " + tool_name + "\n", color=self.color) | |||||
| print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color) | |||||
| print_text("\n") | |||||
| yield tool_output | 
| DEFAULT_PLUGIN_ID = "langgenius" | 
| from enum import StrEnum | |||||
| class CommonParameterType(StrEnum): | |||||
| SECRET_INPUT = "secret-input" | |||||
| TEXT_INPUT = "text-input" | |||||
| SELECT = "select" | |||||
| STRING = "string" | |||||
| NUMBER = "number" | |||||
| FILE = "file" | |||||
| FILES = "files" | |||||
| SYSTEM_FILES = "system-files" | |||||
| BOOLEAN = "boolean" | |||||
| APP_SELECTOR = "app-selector" | |||||
| MODEL_SELECTOR = "model-selector" | |||||
| TOOLS_SELECTOR = "array[tools]" | |||||
| # TOOL_SELECTOR = "tool-selector" | |||||
| class AppSelectorScope(StrEnum): | |||||
| ALL = "all" | |||||
| CHAT = "chat" | |||||
| WORKFLOW = "workflow" | |||||
| COMPLETION = "completion" | |||||
| class ModelSelectorScope(StrEnum): | |||||
| LLM = "llm" | |||||
| TEXT_EMBEDDING = "text-embedding" | |||||
| RERANK = "rerank" | |||||
| TTS = "tts" | |||||
| SPEECH2TEXT = "speech2text" | |||||
| MODERATION = "moderation" | |||||
| VISION = "vision" | |||||
| class ToolSelectorScope(StrEnum): | |||||
| ALL = "all" | |||||
| CUSTOM = "custom" | |||||
| BUILTIN = "builtin" | |||||
| WORKFLOW = "workflow" | 
| import json | import json | ||||
| import logging | import logging | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from collections.abc import Iterator | |||||
| from collections.abc import Iterator, Sequence | |||||
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Optional | from typing import Optional | ||||
| from pydantic import BaseModel, ConfigDict | from pydantic import BaseModel, ConfigDict | ||||
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | ||||
| from core.entities.provider_entities import ( | from core.entities.provider_entities import ( | ||||
| CustomConfiguration, | CustomConfiguration, | ||||
| ) | ) | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | ||||
| from core.model_runtime.entities.model_entities import FetchFrom, ModelType | |||||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType | |||||
| from core.model_runtime.entities.provider_entities import ( | from core.model_runtime.entities.provider_entities import ( | ||||
| ConfigurateMethod, | ConfigurateMethod, | ||||
| CredentialFormSchema, | CredentialFormSchema, | ||||
| FormType, | FormType, | ||||
| ProviderEntity, | ProviderEntity, | ||||
| ) | ) | ||||
| from core.model_runtime.model_providers import model_provider_factory | |||||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | from core.model_runtime.model_providers.__base.ai_model import AIModel | ||||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.provider import ( | from models.provider import ( | ||||
| LoadBalancingModelConfig, | LoadBalancingModelConfig, | ||||
| continue | continue | ||||
| restrict_models = quota_configuration.restrict_models | restrict_models = quota_configuration.restrict_models | ||||
| if self.system_configuration.credentials is None: | |||||
| return None | |||||
| copy_credentials = self.system_configuration.credentials.copy() | |||||
| copy_credentials = ( | |||||
| self.system_configuration.credentials.copy() if self.system_configuration.credentials else {} | |||||
| ) | |||||
| if restrict_models: | if restrict_models: | ||||
| for restrict_model in restrict_models: | for restrict_model in restrict_models: | ||||
| if ( | if ( | ||||
| if current_quota_configuration is None: | if current_quota_configuration is None: | ||||
| return None | return None | ||||
| if not current_quota_configuration: | |||||
| return SystemConfigurationStatus.UNSUPPORTED | |||||
| return ( | return ( | ||||
| SystemConfigurationStatus.ACTIVE | SystemConfigurationStatus.ACTIVE | ||||
| if current_quota_configuration.is_valid | if current_quota_configuration.is_valid | ||||
| """ | """ | ||||
| return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 | return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 | ||||
| def get_custom_credentials(self, obfuscated: bool = False): | |||||
| def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: | |||||
| """ | """ | ||||
| Get custom credentials. | Get custom credentials. | ||||
| else [], | else [], | ||||
| ) | ) | ||||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: | |||||
| def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: | |||||
| """ | """ | ||||
| Validate custom credentials. | Validate custom credentials. | ||||
| :param credentials: provider credentials | :param credentials: provider credentials | ||||
| if value == HIDDEN_VALUE and key in original_credentials: | if value == HIDDEN_VALUE and key in original_credentials: | ||||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | ||||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||||
| credentials = model_provider_factory.provider_credentials_validate( | credentials = model_provider_factory.provider_credentials_validate( | ||||
| provider=self.provider.provider, credentials=credentials | provider=self.provider.provider, credentials=credentials | ||||
| ) | ) | ||||
| provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| provider_record = Provider( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| provider_type=ProviderType.CUSTOM.value, | |||||
| encrypted_config=json.dumps(credentials), | |||||
| is_valid=True, | |||||
| ) | |||||
| provider_record = Provider() | |||||
| provider_record.tenant_id = self.tenant_id | |||||
| provider_record.provider_name = self.provider.provider | |||||
| provider_record.provider_type = ProviderType.CUSTOM.value | |||||
| provider_record.encrypted_config = json.dumps(credentials) | |||||
| provider_record.is_valid = True | |||||
| db.session.add(provider_record) | db.session.add(provider_record) | ||||
| db.session.commit() | db.session.commit() | ||||
| def custom_model_credentials_validate( | def custom_model_credentials_validate( | ||||
| self, model_type: ModelType, model: str, credentials: dict | self, model_type: ModelType, model: str, credentials: dict | ||||
| ) -> tuple[Optional[ProviderModel], dict]: | |||||
| ) -> tuple[ProviderModel | None, dict]: | |||||
| """ | """ | ||||
| Validate custom model credentials. | Validate custom model credentials. | ||||
| if value == HIDDEN_VALUE and key in original_credentials: | if value == HIDDEN_VALUE and key in original_credentials: | ||||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | ||||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||||
| credentials = model_provider_factory.model_credentials_validate( | credentials = model_provider_factory.model_credentials_validate( | ||||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | ||||
| ) | ) | ||||
| provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| provider_model_record = ProviderModel( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| model_name=model, | |||||
| model_type=model_type.to_origin_model_type(), | |||||
| encrypted_config=json.dumps(credentials), | |||||
| is_valid=True, | |||||
| ) | |||||
| provider_model_record = ProviderModel() | |||||
| provider_model_record.tenant_id = self.tenant_id | |||||
| provider_model_record.provider_name = self.provider.provider | |||||
| provider_model_record.model_name = model | |||||
| provider_model_record.model_type = model_type.to_origin_model_type() | |||||
| provider_model_record.encrypted_config = json.dumps(credentials) | |||||
| provider_model_record.is_valid = True | |||||
| db.session.add(provider_model_record) | db.session.add(provider_model_record) | ||||
| db.session.commit() | db.session.commit() | ||||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| model_setting = ProviderModelSetting( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| model_type=model_type.to_origin_model_type(), | |||||
| model_name=model, | |||||
| enabled=True, | |||||
| ) | |||||
| model_setting = ProviderModelSetting() | |||||
| model_setting.tenant_id = self.tenant_id | |||||
| model_setting.provider_name = self.provider.provider | |||||
| model_setting.model_type = model_type.to_origin_model_type() | |||||
| model_setting.model_name = model | |||||
| model_setting.enabled = True | |||||
| db.session.add(model_setting) | db.session.add(model_setting) | ||||
| db.session.commit() | db.session.commit() | ||||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| model_setting = ProviderModelSetting( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| model_type=model_type.to_origin_model_type(), | |||||
| model_name=model, | |||||
| enabled=False, | |||||
| ) | |||||
| model_setting = ProviderModelSetting() | |||||
| model_setting.tenant_id = self.tenant_id | |||||
| model_setting.provider_name = self.provider.provider | |||||
| model_setting.model_type = model_type.to_origin_model_type() | |||||
| model_setting.model_name = model | |||||
| model_setting.enabled = False | |||||
| db.session.add(model_setting) | db.session.add(model_setting) | ||||
| db.session.commit() | db.session.commit() | ||||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| model_setting = ProviderModelSetting( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| model_type=model_type.to_origin_model_type(), | |||||
| model_name=model, | |||||
| load_balancing_enabled=True, | |||||
| ) | |||||
| model_setting = ProviderModelSetting() | |||||
| model_setting.tenant_id = self.tenant_id | |||||
| model_setting.provider_name = self.provider.provider | |||||
| model_setting.model_type = model_type.to_origin_model_type() | |||||
| model_setting.model_name = model | |||||
| model_setting.load_balancing_enabled = True | |||||
| db.session.add(model_setting) | db.session.add(model_setting) | ||||
| db.session.commit() | db.session.commit() | ||||
| model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | else: | ||||
| model_setting = ProviderModelSetting( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| model_type=model_type.to_origin_model_type(), | |||||
| model_name=model, | |||||
| load_balancing_enabled=False, | |||||
| ) | |||||
| model_setting = ProviderModelSetting() | |||||
| model_setting.tenant_id = self.tenant_id | |||||
| model_setting.provider_name = self.provider.provider | |||||
| model_setting.model_type = model_type.to_origin_model_type() | |||||
| model_setting.model_name = model | |||||
| model_setting.load_balancing_enabled = False | |||||
| db.session.add(model_setting) | db.session.add(model_setting) | ||||
| db.session.commit() | db.session.commit() | ||||
| return model_setting | return model_setting | ||||
| def get_provider_instance(self) -> ModelProvider: | |||||
| """ | |||||
| Get provider instance. | |||||
| :return: | |||||
| """ | |||||
| return model_provider_factory.get_provider_instance(self.provider.provider) | |||||
| def get_model_type_instance(self, model_type: ModelType) -> AIModel: | def get_model_type_instance(self, model_type: ModelType) -> AIModel: | ||||
| """ | """ | ||||
| Get current model type instance. | Get current model type instance. | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get provider instance | |||||
| provider_instance = self.get_provider_instance() | |||||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||||
| # Get model instance of LLM | # Get model instance of LLM | ||||
| return provider_instance.get_model_instance(model_type) | |||||
| return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) | |||||
| def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None: | |||||
| """ | |||||
| Get model schema | |||||
| """ | |||||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||||
| return model_provider_factory.get_model_schema( | |||||
| provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | |||||
| ) | |||||
| def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: | def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: | ||||
| """ | """ | ||||
| if preferred_model_provider: | if preferred_model_provider: | ||||
| preferred_model_provider.preferred_provider_type = provider_type.value | preferred_model_provider.preferred_provider_type = provider_type.value | ||||
| else: | else: | ||||
| preferred_model_provider = TenantPreferredModelProvider( | |||||
| tenant_id=self.tenant_id, | |||||
| provider_name=self.provider.provider, | |||||
| preferred_provider_type=provider_type.value, | |||||
| ) | |||||
| preferred_model_provider = TenantPreferredModelProvider() | |||||
| preferred_model_provider.tenant_id = self.tenant_id | |||||
| preferred_model_provider.provider_name = self.provider.provider | |||||
| preferred_model_provider.preferred_provider_type = provider_type.value | |||||
| db.session.add(preferred_model_provider) | db.session.add(preferred_model_provider) | ||||
| db.session.commit() | db.session.commit() | ||||
| :param only_active: only active models | :param only_active: only active models | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| provider_instance = self.get_provider_instance() | |||||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||||
| provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) | |||||
| model_types = [] | |||||
| model_types: list[ModelType] = [] | |||||
| if model_type: | if model_type: | ||||
| model_types.append(model_type) | model_types.append(model_type) | ||||
| else: | else: | ||||
| model_types = list(provider_instance.get_provider_schema().supported_model_types) | |||||
| model_types = list(provider_schema.supported_model_types) | |||||
| # Group model settings by model type and model | # Group model settings by model type and model | ||||
| model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) | model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) | ||||
| if self.using_provider_type == ProviderType.SYSTEM: | if self.using_provider_type == ProviderType.SYSTEM: | ||||
| provider_models = self._get_system_provider_models( | provider_models = self._get_system_provider_models( | ||||
| model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | |||||
| model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map | |||||
| ) | ) | ||||
| else: | else: | ||||
| provider_models = self._get_custom_provider_models( | provider_models = self._get_custom_provider_models( | ||||
| model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | |||||
| model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map | |||||
| ) | ) | ||||
| if only_active: | if only_active: | ||||
| def _get_system_provider_models( | def _get_system_provider_models( | ||||
| self, | self, | ||||
| model_types: list[ModelType], | |||||
| provider_instance: ModelProvider, | |||||
| model_types: Sequence[ModelType], | |||||
| provider_schema: ProviderEntity, | |||||
| model_setting_map: dict[ModelType, dict[str, ModelSettings]], | model_setting_map: dict[ModelType, dict[str, ModelSettings]], | ||||
| ) -> list[ModelWithProviderEntity]: | ) -> list[ModelWithProviderEntity]: | ||||
| """ | """ | ||||
| Get system provider models. | Get system provider models. | ||||
| :param model_types: model types | :param model_types: model types | ||||
| :param provider_instance: provider instance | |||||
| :param provider_schema: provider schema | |||||
| :param model_setting_map: model setting map | :param model_setting_map: model setting map | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| provider_models = [] | provider_models = [] | ||||
| for model_type in model_types: | for model_type in model_types: | ||||
| for m in provider_instance.models(model_type): | |||||
| for m in provider_schema.models: | |||||
| if m.model_type != model_type: | |||||
| continue | |||||
| status = ModelStatus.ACTIVE | status = ModelStatus.ACTIVE | ||||
| if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | |||||
| if m.model in model_setting_map: | |||||
| model_setting = model_setting_map[m.model_type][m.model] | model_setting = model_setting_map[m.model_type][m.model] | ||||
| if model_setting.enabled is False: | if model_setting.enabled is False: | ||||
| status = ModelStatus.DISABLED | status = ModelStatus.DISABLED | ||||
| if self.provider.provider not in original_provider_configurate_methods: | if self.provider.provider not in original_provider_configurate_methods: | ||||
| original_provider_configurate_methods[self.provider.provider] = [] | original_provider_configurate_methods[self.provider.provider] = [] | ||||
| for configurate_method in provider_instance.get_provider_schema().configurate_methods: | |||||
| for configurate_method in provider_schema.configurate_methods: | |||||
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | original_provider_configurate_methods[self.provider.provider].append(configurate_method) | ||||
| should_use_custom_model = False | should_use_custom_model = False | ||||
| ]: | ]: | ||||
| # only customizable model | # only customizable model | ||||
| for restrict_model in restrict_models: | for restrict_model in restrict_models: | ||||
| if self.system_configuration.credentials is not None: | |||||
| copy_credentials = self.system_configuration.credentials.copy() | |||||
| if restrict_model.base_model_name: | |||||
| copy_credentials["base_model_name"] = restrict_model.base_model_name | |||||
| try: | |||||
| custom_model_schema = provider_instance.get_model_instance( | |||||
| restrict_model.model_type | |||||
| ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) | |||||
| except Exception as ex: | |||||
| logger.warning(f"get custom model schema failed, {ex}") | |||||
| continue | |||||
| copy_credentials = ( | |||||
| self.system_configuration.credentials.copy() | |||||
| if self.system_configuration.credentials | |||||
| else {} | |||||
| ) | |||||
| if restrict_model.base_model_name: | |||||
| copy_credentials["base_model_name"] = restrict_model.base_model_name | |||||
| try: | |||||
| custom_model_schema = self.get_model_schema( | |||||
| model_type=restrict_model.model_type, | |||||
| model=restrict_model.model, | |||||
| credentials=copy_credentials, | |||||
| ) | |||||
| except Exception as ex: | |||||
| logger.warning(f"get custom model schema failed, {ex}") | |||||
| if not custom_model_schema: | if not custom_model_schema: | ||||
| continue | continue | ||||
| def _get_custom_provider_models( | def _get_custom_provider_models( | ||||
| self, | self, | ||||
| model_types: list[ModelType], | |||||
| provider_instance: ModelProvider, | |||||
| model_types: Sequence[ModelType], | |||||
| provider_schema: ProviderEntity, | |||||
| model_setting_map: dict[ModelType, dict[str, ModelSettings]], | model_setting_map: dict[ModelType, dict[str, ModelSettings]], | ||||
| ) -> list[ModelWithProviderEntity]: | ) -> list[ModelWithProviderEntity]: | ||||
| """ | """ | ||||
| Get custom provider models. | Get custom provider models. | ||||
| :param model_types: model types | :param model_types: model types | ||||
| :param provider_instance: provider instance | |||||
| :param provider_schema: provider schema | |||||
| :param model_setting_map: model setting map | :param model_setting_map: model setting map | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| if model_type not in self.provider.supported_model_types: | if model_type not in self.provider.supported_model_types: | ||||
| continue | continue | ||||
| models = provider_instance.models(model_type) | |||||
| for m in models: | |||||
| for m in provider_schema.models: | |||||
| if m.model_type != model_type: | |||||
| continue | |||||
| status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE | status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE | ||||
| load_balancing_enabled = False | load_balancing_enabled = False | ||||
| if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | ||||
| continue | continue | ||||
| try: | try: | ||||
| custom_model_schema = provider_instance.get_model_instance( | |||||
| model_configuration.model_type | |||||
| ).get_customizable_model_schema_from_credentials( | |||||
| model_configuration.model, model_configuration.credentials | |||||
| custom_model_schema = self.get_model_schema( | |||||
| model_type=model_configuration.model_type, | |||||
| model=model_configuration.model, | |||||
| credentials=model_configuration.credentials, | |||||
| ) | ) | ||||
| except Exception as ex: | except Exception as ex: | ||||
| logger.warning(f"get custom model schema failed, {ex}") | logger.warning(f"get custom model schema failed, {ex}") | ||||
| label=custom_model_schema.label, | label=custom_model_schema.label, | ||||
| model_type=custom_model_schema.model_type, | model_type=custom_model_schema.model_type, | ||||
| features=custom_model_schema.features, | features=custom_model_schema.features, | ||||
| fetch_from=custom_model_schema.fetch_from, | |||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||||
| model_properties=custom_model_schema.model_properties, | model_properties=custom_model_schema.model_properties, | ||||
| deprecated=custom_model_schema.deprecated, | deprecated=custom_model_schema.deprecated, | ||||
| provider=SimpleModelProviderEntity(self.provider), | provider=SimpleModelProviderEntity(self.provider), | ||||
| return list(self.values()) | return list(self.values()) | ||||
| def __getitem__(self, key): | def __getitem__(self, key): | ||||
| if "/" not in key: | |||||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||||
| return self.configurations[key] | return self.configurations[key] | ||||
| def __setitem__(self, key, value): | def __setitem__(self, key, value): | ||||
| def values(self) -> Iterator[ProviderConfiguration]: | def values(self) -> Iterator[ProviderConfiguration]: | ||||
| return iter(self.configurations.values()) | return iter(self.configurations.values()) | ||||
| def get(self, key, default=None): | |||||
| return self.configurations.get(key, default) | |||||
| def get(self, key, default=None) -> ProviderConfiguration | None: | |||||
| if "/" not in key: | |||||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||||
| return self.configurations.get(key, default) # type: ignore | |||||
| class ProviderModelBundle(BaseModel): | class ProviderModelBundle(BaseModel): | ||||
| """ | """ | ||||
| configuration: ProviderConfiguration | configuration: ProviderConfiguration | ||||
| provider_instance: ModelProvider | |||||
| model_type_instance: AIModel | model_type_instance: AIModel | ||||
| # pydantic configs | # pydantic configs | 
| from enum import Enum | from enum import Enum | ||||
| from typing import Optional | |||||
| from typing import Optional, Union | |||||
| from pydantic import BaseModel, ConfigDict | |||||
| from pydantic import BaseModel, ConfigDict, Field | |||||
| from core.entities.parameter_entities import ( | |||||
| AppSelectorScope, | |||||
| CommonParameterType, | |||||
| ModelSelectorScope, | |||||
| ToolSelectorScope, | |||||
| ) | |||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from models.provider import ProviderQuotaType | |||||
| from core.tools.entities.common_entities import I18nObject | |||||
| class ProviderQuotaType(Enum): | |||||
| PAID = "paid" | |||||
| """hosted paid quota""" | |||||
| FREE = "free" | |||||
| """third-party free quota""" | |||||
| TRIAL = "trial" | |||||
| """hosted trial quota""" | |||||
| @staticmethod | |||||
| def value_of(value): | |||||
| for member in ProviderQuotaType: | |||||
| if member.value == value: | |||||
| return member | |||||
| raise ValueError(f"No matching enum found for value '{value}'") | |||||
| class QuotaUnit(Enum): | class QuotaUnit(Enum): | ||||
| # pydantic configs | # pydantic configs | ||||
| model_config = ConfigDict(protected_namespaces=()) | model_config = ConfigDict(protected_namespaces=()) | ||||
| class BasicProviderConfig(BaseModel): | |||||
| """ | |||||
| Base model class for common provider settings like credentials | |||||
| """ | |||||
| class Type(Enum): | |||||
| SECRET_INPUT = CommonParameterType.SECRET_INPUT.value | |||||
| TEXT_INPUT = CommonParameterType.TEXT_INPUT.value | |||||
| SELECT = CommonParameterType.SELECT.value | |||||
| BOOLEAN = CommonParameterType.BOOLEAN.value | |||||
| APP_SELECTOR = CommonParameterType.APP_SELECTOR.value | |||||
| MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value | |||||
| @classmethod | |||||
| def value_of(cls, value: str) -> "ProviderConfig.Type": | |||||
| """ | |||||
| Get value of given mode. | |||||
| :param value: mode value | |||||
| :return: mode | |||||
| """ | |||||
| for mode in cls: | |||||
| if mode.value == value: | |||||
| return mode | |||||
| raise ValueError(f"invalid mode value {value}") | |||||
| type: Type = Field(..., description="The type of the credentials") | |||||
| name: str = Field(..., description="The name of the credentials") | |||||
| class ProviderConfig(BasicProviderConfig): | |||||
| """ | |||||
| Model class for common provider settings like credentials | |||||
| """ | |||||
| class Option(BaseModel): | |||||
| value: str = Field(..., description="The value of the option") | |||||
| label: I18nObject = Field(..., description="The label of the option") | |||||
| scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None | |||||
| required: bool = False | |||||
| default: Optional[Union[int, str]] = None | |||||
| options: Optional[list[Option]] = None | |||||
| label: Optional[I18nObject] = None | |||||
| help: Optional[I18nObject] = None | |||||
| url: Optional[str] = None | |||||
| placeholder: Optional[I18nObject] = None | |||||
| def to_basic_provider_config(self) -> BasicProviderConfig: | |||||
| return BasicProviderConfig(type=self.type, name=self.name) | 
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | ||||
| def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: | |||||
| url = f"{dify_config.FILES_URL}/files/upload/for-plugin" | |||||
| if user_id is None: | |||||
| user_id = "DEFAULT-USER" | |||||
| timestamp = str(int(time.time())) | |||||
| nonce = os.urandom(16).hex() | |||||
| key = dify_config.SECRET_KEY.encode() | |||||
| msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" | |||||
| sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() | |||||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||||
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" | |||||
| def verify_plugin_file_signature( | |||||
| *, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str | |||||
| ) -> bool: | |||||
| if user_id is None: | |||||
| user_id = "DEFAULT-USER" | |||||
| data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" | |||||
| secret_key = dify_config.SECRET_KEY.encode() | |||||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||||
| # verify signature | |||||
| if sign != recalculated_encoded_sign: | |||||
| return False | |||||
| current_time = int(time.time()) | |||||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||||
| def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | ||||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | ||||
| secret_key = dify_config.SECRET_KEY.encode() | secret_key = dify_config.SECRET_KEY.encode() | 
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from typing import Optional | |||||
| from typing import Any, Optional | |||||
| from pydantic import BaseModel, Field, model_validator | from pydantic import BaseModel, Field, model_validator | ||||
| tool_file_id=self.related_id, extension=self.extension | tool_file_id=self.related_id, extension=self.extension | ||||
| ) | ) | ||||
| def to_plugin_parameter(self) -> dict[str, Any]: | |||||
| return { | |||||
| "dify_model_identity": FILE_MODEL_IDENTITY, | |||||
| "mime_type": self.mime_type, | |||||
| "filename": self.filename, | |||||
| "extension": self.extension, | |||||
| "size": self.size, | |||||
| "type": self.type, | |||||
| "url": self.generate_url(), | |||||
| } | |||||
| @model_validator(mode="after") | @model_validator(mode="after") | ||||
| def validate_after(self): | def validate_after(self): | ||||
| match self.transfer_method: | match self.transfer_method: | 
| import base64 | |||||
| import logging | |||||
| import time | |||||
| from typing import Optional | |||||
| from configs import dify_config | |||||
| from core.helper.url_signer import UrlSigner | |||||
| from extensions.ext_storage import storage | |||||
| IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] | |||||
| IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) | |||||
| class UploadFileParser: | |||||
| @classmethod | |||||
| def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: | |||||
| if not upload_file: | |||||
| return None | |||||
| if upload_file.extension not in IMAGE_EXTENSIONS: | |||||
| return None | |||||
| if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url: | |||||
| return cls.get_signed_temp_image_url(upload_file.id) | |||||
| else: | |||||
| # get image file base64 | |||||
| try: | |||||
| data = storage.load(upload_file.key) | |||||
| except FileNotFoundError: | |||||
| logging.exception(f"File not found: {upload_file.key}") | |||||
| return None | |||||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||||
| return f"data:{upload_file.mime_type};base64,{encoded_string}" | |||||
| @classmethod | |||||
| def get_signed_temp_image_url(cls, upload_file_id) -> str: | |||||
| """ | |||||
| get signed url from upload file | |||||
| :param upload_file: UploadFile object | |||||
| :return: | |||||
| """ | |||||
| base_url = dify_config.FILES_URL | |||||
| image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" | |||||
| return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview") | |||||
| @classmethod | |||||
| def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||||
| """ | |||||
| verify signature | |||||
| :param upload_file_id: file id | |||||
| :param timestamp: timestamp | |||||
| :param nonce: nonce | |||||
| :param sign: signature | |||||
| :return: | |||||
| """ | |||||
| result = UrlSigner.verify( | |||||
| sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview" | |||||
| ) | |||||
| # verify signature | |||||
| if not result: | |||||
| return False | |||||
| current_time = int(time.time()) | |||||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | 
| from core.helper import ssrf_proxy | |||||
| def download_with_size_limit(url, max_download_size: int, **kwargs): | |||||
| response = ssrf_proxy.get(url, follow_redirects=True, **kwargs) | |||||
| if response.status_code == 404: | |||||
| raise ValueError("file not found") | |||||
| total_size = 0 | |||||
| chunks = [] | |||||
| for chunk in response.iter_bytes(): | |||||
| total_size += len(chunk) | |||||
| if total_size > max_download_size: | |||||
| raise ValueError("Max file size reached") | |||||
| chunks.append(chunk) | |||||
| content = b"".join(chunks) | |||||
| return content | 
| from collections.abc import Sequence | |||||
| import requests | |||||
| from yarl import URL | |||||
| from configs import dify_config | |||||
| from core.helper.download import download_with_size_limit | |||||
| from core.plugin.entities.marketplace import MarketplacePluginDeclaration | |||||
| def get_plugin_pkg_url(plugin_unique_identifier: str): | |||||
| return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query( | |||||
| unique_identifier=plugin_unique_identifier | |||||
| ) | |||||
| def download_plugin_pkg(plugin_unique_identifier: str): | |||||
| url = str(get_plugin_pkg_url(plugin_unique_identifier)) | |||||
| return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE) | |||||
| def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: | |||||
| if len(plugin_ids) == 0: | |||||
| return [] | |||||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch") | |||||
| response = requests.post(url, json={"plugin_ids": plugin_ids}) | |||||
| response.raise_for_status() | |||||
| return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] | |||||
| def record_install_plugin_event(plugin_unique_identifier: str): | |||||
| url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count") | |||||
| response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) | |||||
| response.raise_for_status() | 
| import logging | import logging | ||||
| import random | import random | ||||
| from typing import cast | |||||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.errors.invoke import InvokeBadRequestError | from core.model_runtime.errors.invoke import InvokeBadRequestError | ||||
| from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel | |||||
| from core.model_runtime.model_providers.__base.moderation_model import ModerationModel | |||||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from extensions.ext_hosting_provider import hosting_configuration | from extensions.ext_hosting_provider import hosting_configuration | ||||
| from models.provider import ProviderType | from models.provider import ProviderType | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: | |||||
| def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: | |||||
| moderation_config = hosting_configuration.moderation_config | moderation_config = hosting_configuration.moderation_config | ||||
| openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai" | |||||
| if ( | if ( | ||||
| moderation_config | moderation_config | ||||
| and moderation_config.enabled is True | and moderation_config.enabled is True | ||||
| and "openai" in hosting_configuration.provider_map | |||||
| and hosting_configuration.provider_map["openai"].enabled is True | |||||
| and openai_provider_name in hosting_configuration.provider_map | |||||
| and hosting_configuration.provider_map[openai_provider_name].enabled is True | |||||
| ): | ): | ||||
| using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type | using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type | ||||
| provider_name = model_config.provider | provider_name = model_config.provider | ||||
| if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: | if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: | ||||
| hosting_openai_config = hosting_configuration.provider_map["openai"] | |||||
| assert hosting_openai_config is not None | |||||
| hosting_openai_config = hosting_configuration.provider_map[openai_provider_name] | |||||
| if hosting_openai_config.credentials is None: | |||||
| return False | |||||
| # 2000 text per chunk | # 2000 text per chunk | ||||
| length = 2000 | length = 2000 | ||||
| text_chunk = random.choice(text_chunks) | text_chunk = random.choice(text_chunks) | ||||
| try: | try: | ||||
| model_type_instance = OpenAIModerationModel() | |||||
| # FIXME, for type hint using assert or raise ValueError is better here? | |||||
| model_provider_factory = ModelProviderFactory(tenant_id) | |||||
| # Get model instance of LLM | |||||
| model_type_instance = model_provider_factory.get_model_type_instance( | |||||
| provider=openai_provider_name, model_type=ModelType.MODERATION | |||||
| ) | |||||
| model_type_instance = cast(ModerationModel, model_type_instance) | |||||
| moderation_result = model_type_instance.invoke( | moderation_result = model_type_instance.invoke( | ||||
| model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk | |||||
| model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk | |||||
| ) | ) | ||||
| if moderation_result is True: | if moderation_result is True: | ||||
| return True | return True | ||||
| except Exception as ex: | |||||
| except Exception: | |||||
| logger.exception(f"Fails to check moderation, provider_name: {provider_name}") | logger.exception(f"Fails to check moderation, provider_name: {provider_name}") | ||||
| raise InvokeBadRequestError("Rate limit exceeded, please try again later.") | raise InvokeBadRequestError("Rate limit exceeded, please try again later.") | ||||
| ) | ) | ||||
| retries = 0 | retries = 0 | ||||
| stream = kwargs.pop("stream", False) | |||||
| while retries <= max_retries: | while retries <= max_retries: | ||||
| try: | try: | ||||
| if dify_config.SSRF_PROXY_ALL_URL: | if dify_config.SSRF_PROXY_ALL_URL: | 
| class ToolProviderCredentialsCacheType(Enum): | class ToolProviderCredentialsCacheType(Enum): | ||||
| PROVIDER = "tool_provider" | PROVIDER = "tool_provider" | ||||
| ENDPOINT = "endpoint" | |||||
| class ToolProviderCredentialsCache: | class ToolProviderCredentialsCache: | 
| import base64 | |||||
| import hashlib | |||||
| import hmac | |||||
| import os | |||||
| import time | |||||
| from pydantic import BaseModel, Field | |||||
| from configs import dify_config | |||||
| class SignedUrlParams(BaseModel): | |||||
| sign_key: str = Field(..., description="The sign key") | |||||
| timestamp: str = Field(..., description="Timestamp") | |||||
| nonce: str = Field(..., description="Nonce") | |||||
| sign: str = Field(..., description="Signature") | |||||
| class UrlSigner: | |||||
| @classmethod | |||||
| def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: | |||||
| signed_url_params = cls.get_signed_url_params(sign_key, prefix) | |||||
| return ( | |||||
| f"{url}?timestamp={signed_url_params.timestamp}" | |||||
| f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" | |||||
| ) | |||||
| @classmethod | |||||
| def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: | |||||
| timestamp = str(int(time.time())) | |||||
| nonce = os.urandom(16).hex() | |||||
| sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||||
| return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) | |||||
| @classmethod | |||||
| def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: | |||||
| recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) | |||||
| return sign == recalculated_sign | |||||
| @classmethod | |||||
| def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: | |||||
| if not dify_config.SECRET_KEY: | |||||
| raise Exception("SECRET_KEY is not set") | |||||
| data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" | |||||
| secret_key = dify_config.SECRET_KEY.encode() | |||||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||||
| return encoded_sign | 
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | |||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel | |||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from models.provider import ProviderQuotaType | |||||
| class HostingQuota(BaseModel): | class HostingQuota(BaseModel): | ||||
| if dify_config.EDITION != "CLOUD": | if dify_config.EDITION != "CLOUD": | ||||
| return | return | ||||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||||
| self.provider_map["openai"] = self.init_openai() | |||||
| self.provider_map["anthropic"] = self.init_anthropic() | |||||
| self.provider_map["minimax"] = self.init_minimax() | |||||
| self.provider_map["spark"] = self.init_spark() | |||||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() | |||||
| self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() | |||||
| self.moderation_config = self.init_moderation_config() | self.moderation_config = self.init_moderation_config() | ||||
| @staticmethod | @staticmethod | ||||
| def init_moderation_config() -> HostedModerationConfig: | def init_moderation_config() -> HostedModerationConfig: | ||||
| if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: | if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: | ||||
| return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) | |||||
| providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",") | |||||
| hosted_providers = [] | |||||
| for provider in providers: | |||||
| if "/" not in provider: | |||||
| provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}" | |||||
| hosted_providers.append(provider) | |||||
| return HostedModerationConfig(enabled=True, providers=hosted_providers) | |||||
| return HostedModerationConfig(enabled=False) | return HostedModerationConfig(enabled=False) | ||||
| FixedRecursiveCharacterTextSplitter, | FixedRecursiveCharacterTextSplitter, | ||||
| ) | ) | ||||
| from core.rag.splitter.text_splitter import TextSplitter | from core.rag.splitter.text_splitter import TextSplitter | ||||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||||
| from core.tools.utils.rag_web_reader import get_image_upload_file_ids | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| tokens = 0 | tokens = 0 | ||||
| if embedding_model_instance: | if embedding_model_instance: | ||||
| tokens += sum( | |||||
| embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) | |||||
| for document in chunk_documents | |||||
| ) | |||||
| page_content_list = [document.page_content for document in chunk_documents] | |||||
| tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) | |||||
| # load index | # load index | ||||
| index_processor.load(dataset, chunk_documents, with_keywords=False) | index_processor.load(dataset, chunk_documents, with_keywords=False) | 
| response = cast( | response = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||||
| prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| answer = cast(str, response.message.content) | answer = cast(str, response.message.content) | ||||
| response = cast( | response = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=prompt_messages, | |||||
| prompt_messages=list(prompt_messages), | |||||
| model_parameters={"max_tokens": 256, "temperature": 0}, | model_parameters={"max_tokens": 256, "temperature": 0}, | ||||
| stream=False, | stream=False, | ||||
| ), | ), | ||||
| questions = output_parser.parse(cast(str, response.message.content)) | questions = output_parser.parse(cast(str, response.message.content)) | ||||
| except InvokeError: | except InvokeError: | ||||
| questions = [] | questions = [] | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("Failed to generate suggested questions after answer") | logging.exception("Failed to generate suggested questions after answer") | ||||
| questions = [] | questions = [] | ||||
| response = cast( | response = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| prompt_content = cast( | prompt_content = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| except InvokeError as e: | except InvokeError as e: | ||||
| parameter_content = cast( | parameter_content = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False | |||||
| prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) | rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) | ||||
| statement_content = cast( | statement_content = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=statement_messages, model_parameters=model_parameters, stream=False | |||||
| prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| rule_config["opening_statement"] = cast(str, statement_content.message.content) | rule_config["opening_statement"] = cast(str, statement_content.message.content) | ||||
| response = cast( | response = cast( | ||||
| LLMResult, | LLMResult, | ||||
| model_instance.invoke_llm( | model_instance.invoke_llm( | ||||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||||
| ), | ), | ||||
| ) | ) | ||||
| import logging | import logging | ||||
| from collections.abc import Callable, Generator, Iterable, Sequence | from collections.abc import Callable, Generator, Iterable, Sequence | ||||
| from typing import IO, Any, Optional, Union, cast | |||||
| from typing import IO, Any, Literal, Optional, Union, cast, overload | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.entities.embedding_type import EmbeddingInputType | from core.entities.embedding_type import EmbeddingInputType | ||||
| return None | return None | ||||
| @overload | |||||
| def invoke_llm( | |||||
| self, | |||||
| prompt_messages: list[PromptMessage], | |||||
| model_parameters: Optional[dict] = None, | |||||
| tools: Sequence[PromptMessageTool] | None = None, | |||||
| stop: Optional[list[str]] = None, | |||||
| stream: Literal[True] = True, | |||||
| user: Optional[str] = None, | |||||
| callbacks: Optional[list[Callback]] = None, | |||||
| ) -> Generator: ... | |||||
| @overload | |||||
| def invoke_llm( | |||||
| self, | |||||
| prompt_messages: list[PromptMessage], | |||||
| model_parameters: Optional[dict] = None, | |||||
| tools: Sequence[PromptMessageTool] | None = None, | |||||
| stop: Optional[list[str]] = None, | |||||
| stream: Literal[False] = False, | |||||
| user: Optional[str] = None, | |||||
| callbacks: Optional[list[Callback]] = None, | |||||
| ) -> LLMResult: ... | |||||
| @overload | |||||
| def invoke_llm( | |||||
| self, | |||||
| prompt_messages: list[PromptMessage], | |||||
| model_parameters: Optional[dict] = None, | |||||
| tools: Sequence[PromptMessageTool] | None = None, | |||||
| stop: Optional[list[str]] = None, | |||||
| stream: bool = True, | |||||
| user: Optional[str] = None, | |||||
| callbacks: Optional[list[Callback]] = None, | |||||
| ) -> Union[LLMResult, Generator]: ... | |||||
| def invoke_llm( | def invoke_llm( | ||||
| self, | self, | ||||
| prompt_messages: Sequence[PromptMessage], | prompt_messages: Sequence[PromptMessage], | ||||
| ), | ), | ||||
| ) | ) | ||||
| def get_text_embedding_num_tokens(self, texts: list[str]) -> int: | |||||
| def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: | |||||
| """ | """ | ||||
| Get number of tokens for text embedding | Get number of tokens for text embedding | ||||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | ||||
| return cast( | return cast( | ||||
| int, | |||||
| list[int], | |||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| function=self.model_type_instance.get_num_tokens, | function=self.model_type_instance.get_num_tokens, | ||||
| model=self.model, | model=self.model, | ||||
| return ModelInstance(provider_model_bundle, model) | return ModelInstance(provider_model_bundle, model) | ||||
| def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: | |||||
| def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: | |||||
| """ | """ | ||||
| Return first provider and the first model in the provider | Return first provider and the first model in the provider | ||||
| :param tenant_id: tenant id | :param tenant_id: tenant id |