| @@ -8,6 +8,7 @@ on: | |||
| - "deploy/enterprise" | |||
| - "build/**" | |||
| - "release/e-*" | |||
| - "deploy/rag-dev" | |||
| tags: | |||
| - "*" | |||
| @@ -4,7 +4,7 @@ on: | |||
| workflow_run: | |||
| workflows: ["Build and Push API & Web"] | |||
| branches: | |||
| - "deploy/dev" | |||
| - "deploy/rag-dev" | |||
| types: | |||
| - completed | |||
| @@ -12,12 +12,13 @@ jobs: | |||
| deploy: | |||
| runs-on: ubuntu-latest | |||
| if: | | |||
| github.event.workflow_run.conclusion == 'success' | |||
| github.event.workflow_run.conclusion == 'success' && | |||
| github.event.workflow_run.head_branch == 'deploy/rag-dev' | |||
| steps: | |||
| - name: Deploy to server | |||
| uses: appleboy/ssh-action@v0.1.8 | |||
| with: | |||
| host: ${{ secrets.SSH_HOST }} | |||
| host: ${{ secrets.RAG_SSH_HOST }} | |||
| username: ${{ secrets.SSH_USER }} | |||
| key: ${{ secrets.SSH_PRIVATE_KEY }} | |||
| script: | | |||
| @@ -460,6 +460,16 @@ WORKFLOW_CALL_MAX_DEPTH=5 | |||
| WORKFLOW_PARALLEL_DEPTH_LIMIT=3 | |||
| MAX_VARIABLE_SIZE=204800 | |||
| # GraphEngine Worker Pool Configuration | |||
| # Minimum number of workers per GraphEngine instance (default: 1) | |||
| GRAPH_ENGINE_MIN_WORKERS=1 | |||
| # Maximum number of workers per GraphEngine instance (default: 10) | |||
| GRAPH_ENGINE_MAX_WORKERS=10 | |||
| # Queue depth threshold that triggers worker scale up (default: 3) | |||
| GRAPH_ENGINE_SCALE_UP_THRESHOLD=3 | |||
| # Seconds of idle time before scaling down workers (default: 5.0) | |||
| GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0 | |||
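These variables are read into the WorkflowConfig settings class added later in this diff, so the worker pool can be tuned per deployment without code changes. A minimal sanity check of the effective values, assuming the usual dify_config accessor used elsewhere in the codebase, might look like:

# Sketch only: the field names come from this diff; the MIN <= MAX check is a
# suggested invariant, not something the configuration class enforces.
from configs import dify_config

# GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME carries a ge=0.1 constraint in WorkflowConfig.
assert dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME >= 0.1
assert 1 <= dify_config.GRAPH_ENGINE_MIN_WORKERS <= dify_config.GRAPH_ENGINE_MAX_WORKERS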
| # Workflow storage configuration | |||
| # Options: rdbms, hybrid | |||
| # rdbms: Use only the relational database (default) | |||
| @@ -0,0 +1,122 @@ | |||
| [importlinter] | |||
| root_packages = | |||
| core | |||
| configs | |||
| controllers | |||
| models | |||
| tasks | |||
| services | |||
| [importlinter:contract:workflow] | |||
| name = Workflow | |||
| type = layers | |||
| layers = | |||
| graph_engine | |||
| graph_events | |||
| graph | |||
| nodes | |||
| node_events | |||
| entities | |||
| containers = | |||
| core.workflow | |||
| ignore_imports = | |||
| core.workflow.nodes.base.node -> core.workflow.graph_events | |||
| core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events | |||
| core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine | |||
| core.workflow.nodes.iteration.iteration_node -> core.workflow.graph | |||
| core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels | |||
| core.workflow.nodes.loop.loop_node -> core.workflow.graph_events | |||
| core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine | |||
| core.workflow.nodes.loop.loop_node -> core.workflow.graph | |||
| core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels | |||
| core.workflow.nodes.node_factory -> core.workflow.graph | |||
| [importlinter:contract:rsc] | |||
| name = RSC | |||
| type = layers | |||
| layers = | |||
| graph_engine | |||
| response_coordinator | |||
| output_registry | |||
| containers = | |||
| core.workflow.graph_engine | |||
| [importlinter:contract:worker] | |||
| name = Worker | |||
| type = layers | |||
| layers = | |||
| graph_engine | |||
| worker | |||
| containers = | |||
| core.workflow.graph_engine | |||
| [importlinter:contract:graph-engine-architecture] | |||
| name = Graph Engine Architecture | |||
| type = layers | |||
| layers = | |||
| graph_engine | |||
| orchestration | |||
| command_processing | |||
| event_management | |||
| error_handling | |||
| graph_traversal | |||
| state_management | |||
| worker_management | |||
| domain | |||
| containers = | |||
| core.workflow.graph_engine | |||
| [importlinter:contract:domain-isolation] | |||
| name = Domain Model Isolation | |||
| type = forbidden | |||
| source_modules = | |||
| core.workflow.graph_engine.domain | |||
| forbidden_modules = | |||
| core.workflow.graph_engine.worker_management | |||
| core.workflow.graph_engine.command_channels | |||
| core.workflow.graph_engine.layers | |||
| core.workflow.graph_engine.protocols | |||
| [importlinter:contract:state-management-layers] | |||
| name = State Management Layers | |||
| type = layers | |||
| layers = | |||
| execution_tracker | |||
| node_state_manager | |||
| edge_state_manager | |||
| containers = | |||
| core.workflow.graph_engine.state_management | |||
| [importlinter:contract:worker-management-layers] | |||
| name = Worker Management Layers | |||
| type = layers | |||
| layers = | |||
| worker_pool | |||
| worker_factory | |||
| dynamic_scaler | |||
| activity_tracker | |||
| containers = | |||
| core.workflow.graph_engine.worker_management | |||
| [importlinter:contract:error-handling-strategies] | |||
| name = Error Handling Strategies | |||
| type = independence | |||
| modules = | |||
| core.workflow.graph_engine.error_handling.abort_strategy | |||
| core.workflow.graph_engine.error_handling.retry_strategy | |||
| core.workflow.graph_engine.error_handling.fail_branch_strategy | |||
| core.workflow.graph_engine.error_handling.default_value_strategy | |||
| [importlinter:contract:graph-traversal-components] | |||
| name = Graph Traversal Components | |||
| type = independence | |||
| modules = | |||
| core.workflow.graph_engine.graph_traversal.node_readiness | |||
| core.workflow.graph_engine.graph_traversal.skip_propagator | |||
| [importlinter:contract:command-channels] | |||
| name = Command Channels Independence | |||
| type = independence | |||
| modules = | |||
| core.workflow.graph_engine.command_channels.in_memory_channel | |||
| core.workflow.graph_engine.command_channels.redis_channel | |||
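The file above is consumed by import-linter, which fails the lint step when a contract is violated. As a hedged illustration of what the Domain Model Isolation contract allows and forbids (module paths come from the contract itself; the example module is hypothetical):

# Hypothetical module under core/workflow/graph_engine/domain/ -- illustration only.
from dataclasses import dataclass  # standard-library imports are always allowed


@dataclass
class ExampleDomainValue:
    """A plain value object; the domain layer stays free of engine infrastructure."""
    name: str

# Forbidden by the "Domain Model Isolation" contract and flagged by lint-imports:
# from core.workflow.graph_engine.worker_management import worker_pool
# from core.workflow.graph_engine.command_channels import redis_channel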
| @@ -1,4 +1,3 @@ | |||
| import os | |||
| import sys | |||
| @@ -17,20 +16,20 @@ else: | |||
| # It seems that the JetBrains Python debugger does not work well with gevent, | |||
| # so we need to disable gevent in debug mode. | |||
| # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. | |||
| if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: | |||
| from gevent import monkey | |||
| # gevent | |||
| monkey.patch_all() | |||
| from grpc.experimental import gevent as grpc_gevent # type: ignore | |||
| # grpc gevent | |||
| grpc_gevent.init_gevent() | |||
| import psycogreen.gevent # type: ignore | |||
| psycogreen.gevent.patch_psycopg() | |||
| # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: | |||
| # from gevent import monkey | |||
| # | |||
| # # gevent | |||
| # monkey.patch_all() | |||
| # | |||
| # from grpc.experimental import gevent as grpc_gevent # type: ignore | |||
| # | |||
| # # grpc gevent | |||
| # grpc_gevent.init_gevent() | |||
| # import psycogreen.gevent # type: ignore | |||
| # | |||
| # psycogreen.gevent.patch_psycopg() | |||
| from app_factory import create_app | |||
| @@ -13,11 +13,14 @@ from sqlalchemy.exc import SQLAlchemyError | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| from core.plugin.entities.plugin import ToolProviderID | |||
| from core.helper import encrypter | |||
| from core.plugin.entities.plugin import PluginInstallationSource | |||
| from core.plugin.impl.plugin import PluginInstaller | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField | |||
| from core.rag.models.document import Document | |||
| from core.tools.entities.tool_entities import CredentialType | |||
| from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params | |||
| from events.app_event import app_was_created | |||
| from extensions.ext_database import db | |||
| @@ -30,7 +33,10 @@ from models import Tenant | |||
| from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | |||
| from models.oauth import DatasourceOauthParamConfig, DatasourceProvider | |||
| from models.provider import Provider, ProviderModel | |||
| from models.provider_ids import DatasourceProviderID, ToolProviderID | |||
| from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding | |||
| from models.tools import ToolOAuthSystemClient | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs | |||
| @@ -1354,3 +1360,250 @@ def cleanup_orphaned_draft_variables( | |||
| continue | |||
| logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) | |||
| @click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") | |||
| @click.option("--provider", prompt=True, help="Provider name") | |||
| @click.option("--client-params", prompt=True, help="Client Params") | |||
| def setup_datasource_oauth_client(provider, client_params): | |||
| """ | |||
| Setup datasource oauth client | |||
| """ | |||
| provider_id = DatasourceProviderID(provider) | |||
| provider_name = provider_id.provider_name | |||
| plugin_id = provider_id.plugin_id | |||
| try: | |||
| # validate client params JSON | |||
| click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) | |||
| client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) | |||
| click.echo(click.style("Client params validated successfully.", fg="green")) | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) | |||
| return | |||
| click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) | |||
| deleted_count = ( | |||
| db.session.query(DatasourceOauthParamConfig) | |||
| .filter_by( | |||
| provider=provider_name, | |||
| plugin_id=plugin_id, | |||
| ) | |||
| .delete() | |||
| ) | |||
| if deleted_count > 0: | |||
| click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) | |||
| click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow")) | |||
| oauth_client = DatasourceOauthParamConfig( | |||
| provider=provider_name, | |||
| plugin_id=plugin_id, | |||
| system_credentials=client_params_dict, | |||
| ) | |||
| db.session.add(oauth_client) | |||
| db.session.commit() | |||
| click.echo(click.style(f"provider: {provider_name}", fg="green")) | |||
| click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) | |||
| click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) | |||
| click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green")) | |||
| @click.command("transform-datasource-credentials", help="Transform datasource credentials.") | |||
| def transform_datasource_credentials(): | |||
| """ | |||
| Transform datasource credentials | |||
| """ | |||
| try: | |||
| installer_manager = PluginInstaller() | |||
| plugin_migration = PluginMigration() | |||
| notion_plugin_id = "langgenius/notion_datasource" | |||
| firecrawl_plugin_id = "langgenius/firecrawl_datasource" | |||
| jina_plugin_id = "langgenius/jina_datasource" | |||
| notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) | |||
| firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) | |||
| jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) | |||
| oauth_credential_type = CredentialType.OAUTH2 | |||
| api_key_credential_type = CredentialType.API_KEY | |||
| # migrate notion credentials | |||
| deal_notion_count = 0 | |||
| notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() | |||
| notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} | |||
| for credential in notion_credentials: | |||
| tenant_id = credential.tenant_id | |||
| if tenant_id not in notion_credentials_tenant_mapping: | |||
| notion_credentials_tenant_mapping[tenant_id] = [] | |||
| notion_credentials_tenant_mapping[tenant_id].append(credential) | |||
| for tenant_id, credentials in notion_credentials_tenant_mapping.items(): | |||
| # check notion plugin is installed | |||
| installed_plugins = installer_manager.list_plugins(tenant_id) | |||
| installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] | |||
| if notion_plugin_id not in installed_plugins_ids: | |||
| if notion_plugin_unique_identifier: | |||
| # install notion plugin | |||
| installer_manager.install_from_identifiers( | |||
| tenant_id, | |||
| [notion_plugin_unique_identifier], | |||
| PluginInstallationSource.Marketplace, | |||
| metas=[ | |||
| { | |||
| "plugin_unique_identifier": notion_plugin_unique_identifier, | |||
| } | |||
| ], | |||
| ) | |||
| auth_count = 0 | |||
| for credential in credentials: | |||
| auth_count += 1 | |||
| # get credential oauth params | |||
| access_token = credential.access_token | |||
| # notion info | |||
| notion_info = credential.source_info | |||
| workspace_id = notion_info.get("workspace_id") | |||
| workspace_name = notion_info.get("workspace_name") | |||
| workspace_icon = notion_info.get("workspace_icon") | |||
| new_credentials = { | |||
| "integration_secret": encrypter.encrypt_token(tenant_id, access_token), | |||
| "workspace_id": workspace_id, | |||
| "workspace_name": workspace_name, | |||
| "workspace_icon": workspace_icon, | |||
| } | |||
| datasource_provider = DatasourceProvider( | |||
| provider="notion", | |||
| tenant_id=tenant_id, | |||
| plugin_id=notion_plugin_id, | |||
| auth_type=oauth_credential_type.value, | |||
| encrypted_credentials=new_credentials, | |||
| name=f"Auth {auth_count}", | |||
| avatar_url=workspace_icon or "default", | |||
| is_default=False, | |||
| ) | |||
| db.session.add(datasource_provider) | |||
| deal_notion_count += 1 | |||
| db.session.commit() | |||
| # migrate firecrawl credentials | |||
| deal_firecrawl_count = 0 | |||
| firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() | |||
| firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} | |||
| for credential in firecrawl_credentials: | |||
| tenant_id = credential.tenant_id | |||
| if tenant_id not in firecrawl_credentials_tenant_mapping: | |||
| firecrawl_credentials_tenant_mapping[tenant_id] = [] | |||
| firecrawl_credentials_tenant_mapping[tenant_id].append(credential) | |||
| for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items(): | |||
| # check firecrawl plugin is installed | |||
| installed_plugins = installer_manager.list_plugins(tenant_id) | |||
| installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] | |||
| if firecrawl_plugin_id not in installed_plugins_ids: | |||
| if firecrawl_plugin_unique_identifier: | |||
| # install firecrawl plugin | |||
| installer_manager.install_from_identifiers( | |||
| tenant_id, | |||
| [firecrawl_plugin_unique_identifier], | |||
| PluginInstallationSource.Marketplace, | |||
| metas=[ | |||
| { | |||
| "plugin_unique_identifier": firecrawl_plugin_unique_identifier, | |||
| } | |||
| ], | |||
| ) | |||
| auth_count = 0 | |||
| for credential in credentials: | |||
| auth_count += 1 | |||
| # get credential api key | |||
| api_key = credential.credentials.get("config", {}).get("api_key") | |||
| base_url = credential.credentials.get("config", {}).get("base_url") | |||
| new_credentials = { | |||
| "firecrawl_api_key": api_key, | |||
| "base_url": base_url, | |||
| } | |||
| datasource_provider = DatasourceProvider( | |||
| provider="firecrawl", | |||
| tenant_id=tenant_id, | |||
| plugin_id=firecrawl_plugin_id, | |||
| auth_type=api_key_credential_type.value, | |||
| encrypted_credentials=new_credentials, | |||
| name=f"Auth {auth_count}", | |||
| avatar_url="default", | |||
| is_default=False, | |||
| ) | |||
| db.session.add(datasource_provider) | |||
| deal_firecrawl_count += 1 | |||
| db.session.commit() | |||
| # migrate jina credentials | |||
| deal_jina_count = 0 | |||
| jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jina").all() | |||
| jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} | |||
| for credential in jina_credentials: | |||
| tenant_id = credential.tenant_id | |||
| if tenant_id not in jina_credentials_tenant_mapping: | |||
| jina_credentials_tenant_mapping[tenant_id] = [] | |||
| jina_credentials_tenant_mapping[tenant_id].append(credential) | |||
| for tenant_id, credentials in jina_credentials_tenant_mapping.items(): | |||
| # check jina plugin is installed | |||
| installed_plugins = installer_manager.list_plugins(tenant_id) | |||
| installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] | |||
| if jina_plugin_id not in installed_plugins_ids: | |||
| if jina_plugin_unique_identifier: | |||
| # install jina plugin | |||
| installer_manager.install_from_identifiers( | |||
| tenant_id, | |||
| [jina_plugin_unique_identifier], | |||
| PluginInstallationSource.Marketplace, | |||
| metas=[ | |||
| { | |||
| "plugin_unique_identifier": jina_plugin_unique_identifier, | |||
| } | |||
| ], | |||
| ) | |||
| auth_count = 0 | |||
| for credential in credentials: | |||
| auth_count += 1 | |||
| # get credential api key | |||
| api_key = credential.credentials.get("config", {}).get("api_key") | |||
| new_credentials = { | |||
| "integration_secret": api_key, | |||
| } | |||
| datasource_provider = DatasourceProvider( | |||
| provider="jina", | |||
| tenant_id=tenant_id, | |||
| plugin_id=jina_plugin_id, | |||
| auth_type=api_key_credential_type.value, | |||
| encrypted_credentials=new_credentials, | |||
| name=f"Auth {auth_count}", | |||
| avatar_url="default", | |||
| is_default=False, | |||
| ) | |||
| db.session.add(datasource_provider) | |||
| deal_jina_count += 1 | |||
| db.session.commit() | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) | |||
| return | |||
| click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green")) | |||
| click.echo( | |||
| click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green") | |||
| ) | |||
| click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green")) | |||
| @click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.") | |||
| @click.option( | |||
| "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" | |||
| ) | |||
| @click.option( | |||
| "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" | |||
| ) | |||
| @click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) | |||
| def install_rag_pipeline_plugins(input_file, output_file, workers): | |||
| """ | |||
| Install rag pipeline plugins | |||
| """ | |||
| click.echo(click.style("Installing rag pipeline plugins", fg="yellow")) | |||
| plugin_migration = PluginMigration() | |||
| plugin_migration.install_rag_pipeline_plugins( | |||
| input_file, | |||
| output_file, | |||
| workers, | |||
| ) | |||
| click.echo(click.style("Installing rag pipeline plugins successfully", fg="green")) | |||
| @@ -545,6 +545,28 @@ class WorkflowConfig(BaseSettings): | |||
| default=200 * 1024, | |||
| ) | |||
| # GraphEngine Worker Pool Configuration | |||
| GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( | |||
| description="Minimum number of workers per GraphEngine instance", | |||
| default=1, | |||
| ) | |||
| GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field( | |||
| description="Maximum number of workers per GraphEngine instance", | |||
| default=10, | |||
| ) | |||
| GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field( | |||
| description="Queue depth threshold that triggers worker scale up", | |||
| default=3, | |||
| ) | |||
| GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field( | |||
| description="Seconds of idle time before scaling down workers", | |||
| default=5.0, | |||
| ge=0.1, | |||
| ) | |||
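A hedged sketch of how a worker pool could act on these four settings; the real logic lives in the graph engine's worker_management package, and the names below are illustrative only:

# Illustrative scaling decision, not the engine's actual implementation.
import time
from dataclasses import dataclass

from configs import dify_config


@dataclass
class PoolSnapshot:
    worker_count: int        # workers currently alive
    queue_depth: int         # tasks waiting to be picked up
    last_activity_at: float  # time.time() of the most recent dequeue


def plan_worker_count(snapshot: PoolSnapshot) -> int:
    """Return the desired worker count given queue pressure and idleness."""
    desired = snapshot.worker_count
    if snapshot.queue_depth >= dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD:
        desired += 1  # queue is backing up: add a worker
    elif time.time() - snapshot.last_activity_at >= dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME:
        desired -= 1  # pool has been idle long enough: drop a worker
    return max(
        dify_config.GRAPH_ENGINE_MIN_WORKERS,
        min(desired, dify_config.GRAPH_ENGINE_MAX_WORKERS),
    )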
| class WorkflowNodeExecutionConfig(BaseSettings): | |||
| """ | |||
| @@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings): | |||
| ) | |||
| class HostedFetchPipelineTemplateConfig(BaseSettings): | |||
| """ | |||
| Configuration for fetching pipeline templates | |||
| """ | |||
| HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( | |||
| description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", | |||
| default="database", | |||
| ) | |||
| HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( | |||
| description="Domain for fetching remote pipeline templates", | |||
| default="https://tmpl.dify.ai", | |||
| ) | |||
| class HostedServiceConfig( | |||
| # place the configs in alphabet order | |||
| HostedAnthropicConfig, | |||
| HostedAzureOpenAiConfig, | |||
| HostedFetchAppTemplateConfig, | |||
| HostedFetchPipelineTemplateConfig, | |||
| HostedMinmaxConfig, | |||
| HostedOpenAiConfig, | |||
| HostedSparkConfig, | |||
| @@ -3,6 +3,7 @@ from threading import Lock | |||
| from typing import TYPE_CHECKING | |||
| from contexts.wrapper import RecyclableContextVar | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| if TYPE_CHECKING: | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| @@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont | |||
| plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( | |||
| ContextVar("plugin_model_schemas") | |||
| ) | |||
| datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( | |||
| RecyclableContextVar(ContextVar("datasource_plugin_providers")) | |||
| ) | |||
| datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( | |||
| ContextVar("datasource_plugin_providers_lock") | |||
| ) | |||
| @@ -43,7 +43,7 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm" | |||
| api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies") | |||
| # Import other controllers | |||
| from . import admin, apikey, extension, feature, ping, setup, version | |||
| from . import admin, apikey, extension, feature, ping, setup, spec, version | |||
| # Import app controllers | |||
| from .app import ( | |||
| @@ -86,6 +86,15 @@ from .datasets import ( | |||
| metadata, | |||
| website, | |||
| ) | |||
| from .datasets.rag_pipeline import ( | |||
| datasource_auth, | |||
| datasource_content_preview, | |||
| rag_pipeline, | |||
| rag_pipeline_datasets, | |||
| rag_pipeline_draft_variable, | |||
| rag_pipeline_import, | |||
| rag_pipeline_workflow, | |||
| ) | |||
| # Import explore controllers | |||
| from .explore import ( | |||
| @@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc | |||
| from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from extensions.ext_database import db | |||
| from libs.login import login_required | |||
| from models import App | |||
| from services.workflow_service import WorkflowService | |||
| class RuleGenerateApi(Resource): | |||
| @@ -135,9 +138,6 @@ class InstructionGenerateApi(Resource): | |||
| try: | |||
| # Generate from nothing for a workflow node | |||
| if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": | |||
| from models import App, db | |||
| from services.workflow_service import WorkflowService | |||
| app = db.session.query(App).where(App.id == args["flow_id"]).first() | |||
| if not app: | |||
| return {"error": f"app {args['flow_id']} not found"}, 400 | |||
| @@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.models import File | |||
| from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.workflow.graph_engine.manager import GraphEngineManager | |||
| from extensions.ext_database import db | |||
| from factories import file_factory, variable_factory | |||
| from fields.workflow_fields import workflow_fields, workflow_pagination_fields | |||
| @@ -413,7 +414,12 @@ class WorkflowTaskStopApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) | |||
| # Stop using both mechanisms for backward compatibility | |||
| # Legacy stop flag mechanism (without user check) | |||
| AppQueueManager.set_stop_flag_no_user_check(task_id) | |||
| # New graph engine command channel mechanism | |||
| GraphEngineManager.send_stop_command(task_id) | |||
| return {"result": "success"} | |||
| @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session | |||
| from controllers.console import api | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from core.workflow.enums import WorkflowExecutionStatus | |||
| from extensions.ext_database import db | |||
| from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from libs.login import login_required | |||
| @@ -18,10 +18,11 @@ from core.variables.segment_group import SegmentGroup | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from extensions.ext_database import db | |||
| from factories.file_factory import build_from_mapping, build_from_mappings | |||
| from factories.variable_factory import build_segment_with_type | |||
| from libs.login import current_user, login_required | |||
| from models import App, AppMode, db | |||
| from models import App, AppMode | |||
| from models.account import Account | |||
| from models.workflow import WorkflowDraftVariable | |||
| from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService | |||
| @@ -1,4 +1,6 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from flask import request | |||
| from flask_login import current_user | |||
| @@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage | |||
| from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| @@ -18,6 +22,7 @@ from libs.datetime_utils import naive_utc_now | |||
| from libs.login import login_required | |||
| from models import DataSourceOauthBinding, Document | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.datasource_provider_service import DatasourceProviderService | |||
| from tasks.document_indexing_sync_task import document_indexing_sync_task | |||
| @@ -112,6 +117,18 @@ class DataSourceNotionListApi(Resource): | |||
| @marshal_with(integrate_notion_info_list_fields) | |||
| def get(self): | |||
| dataset_id = request.args.get("dataset_id", default=None, type=str) | |||
| credential_id = request.args.get("credential_id", default=None, type=str) | |||
| if not credential_id: | |||
| raise ValueError("Credential id is required.") | |||
| datasource_provider_service = DatasourceProviderService() | |||
| credential = datasource_provider_service.get_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| credential_id=credential_id, | |||
| provider="notion_datasource", | |||
| plugin_id="langgenius/notion_datasource", | |||
| ) | |||
| if not credential: | |||
| raise NotFound("Credential not found.") | |||
| exist_page_ids = [] | |||
| with Session(db.engine) as session: | |||
| # import notion in the exist dataset | |||
| @@ -135,31 +152,49 @@ class DataSourceNotionListApi(Resource): | |||
| data_source_info = json.loads(document.data_source_info) | |||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||
| # get all authorized pages | |||
| data_source_bindings = session.scalars( | |||
| select(DataSourceOauthBinding).filter_by( | |||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||
| from core.datasource.datasource_manager import DatasourceManager | |||
| datasource_runtime = DatasourceManager.get_datasource_runtime( | |||
| provider_id="langgenius/notion_datasource/notion_datasource", | |||
| datasource_name="notion_datasource", | |||
| tenant_id=current_user.current_tenant_id, | |||
| datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, | |||
| ) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| if credential: | |||
| datasource_runtime.runtime.credentials = credential | |||
| datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) | |||
| online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( | |||
| datasource_runtime.get_online_document_pages( | |||
| user_id=current_user.id, | |||
| datasource_parameters={}, | |||
| provider_type=datasource_runtime.datasource_provider_type(), | |||
| ) | |||
| ).all() | |||
| if not data_source_bindings: | |||
| return {"notion_info": []}, 200 | |||
| pre_import_info_list = [] | |||
| for data_source_binding in data_source_bindings: | |||
| source_info = data_source_binding.source_info | |||
| pages = source_info["pages"] | |||
| # Filter out already bound pages | |||
| for page in pages: | |||
| if page["page_id"] in exist_page_ids: | |||
| page["is_bound"] = True | |||
| else: | |||
| page["is_bound"] = False | |||
| pre_import_info = { | |||
| "workspace_name": source_info["workspace_name"], | |||
| "workspace_icon": source_info["workspace_icon"], | |||
| "workspace_id": source_info["workspace_id"], | |||
| "pages": pages, | |||
| } | |||
| pre_import_info_list.append(pre_import_info) | |||
| return {"notion_info": pre_import_info_list}, 200 | |||
| ) | |||
| try: | |||
| pages = [] | |||
| workspace_info = {} | |||
| for message in online_document_result: | |||
| result = message.result | |||
| for info in result: | |||
| workspace_info = { | |||
| "workspace_id": info.workspace_id, | |||
| "workspace_name": info.workspace_name, | |||
| "workspace_icon": info.workspace_icon, | |||
| } | |||
| for page in info.pages: | |||
| page_info = { | |||
| "page_id": page.page_id, | |||
| "page_name": page.page_name, | |||
| "type": page.type, | |||
| "parent_id": page.parent_id, | |||
| "is_bound": page.page_id in exist_page_ids, | |||
| "page_icon": page.page_icon, | |||
| } | |||
| pages.append(page_info) | |||
| except Exception as e: | |||
| raise e | |||
| return {"notion_info": {**workspace_info, "pages": pages}}, 200 | |||
| class DataSourceNotionApi(Resource): | |||
| @@ -167,27 +202,25 @@ class DataSourceNotionApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, workspace_id, page_id, page_type): | |||
| credential_id = request.args.get("credential_id", default=None, type=str) | |||
| if not credential_id: | |||
| raise ValueError("Credential id is required.") | |||
| datasource_provider_service = DatasourceProviderService() | |||
| credential = datasource_provider_service.get_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| credential_id=credential_id, | |||
| provider="notion_datasource", | |||
| plugin_id="langgenius/notion_datasource", | |||
| ) | |||
| workspace_id = str(workspace_id) | |||
| page_id = str(page_id) | |||
| with Session(db.engine) as session: | |||
| data_source_binding = session.execute( | |||
| select(DataSourceOauthBinding).where( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ) | |||
| ).scalar_one_or_none() | |||
| if not data_source_binding: | |||
| raise NotFound("Data source binding not found.") | |||
| extractor = NotionExtractor( | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page_id, | |||
| notion_page_type=page_type, | |||
| notion_access_token=data_source_binding.access_token, | |||
| notion_access_token=credential.get("integration_secret"), | |||
| tenant_id=current_user.current_tenant_id, | |||
| ) | |||
| @@ -212,10 +245,12 @@ class DataSourceNotionApi(Resource): | |||
| extract_settings = [] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info["workspace_id"] | |||
| credential_id = notion_info.get("credential_id") | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "credential_id": credential_id, | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page["page_id"], | |||
| "notion_page_type": page["type"], | |||
| @@ -19,7 +19,6 @@ from controllers.console.wraps import ( | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| @@ -31,6 +30,7 @@ from fields.document_fields import document_status_fields | |||
| from libs.login import login_required | |||
| from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile | |||
| from models.dataset import DatasetPermissionEnum | |||
| from models.provider_ids import ModelProviderID | |||
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |||
| @@ -279,6 +279,15 @@ class DatasetApi(Resource): | |||
| location="json", | |||
| help="Invalid external knowledge api id.", | |||
| ) | |||
| parser.add_argument( | |||
| "icon_info", | |||
| type=dict, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid icon info.", | |||
| ) | |||
| args = parser.parse_args() | |||
| data = request.get_json() | |||
| @@ -429,10 +438,12 @@ class DatasetIndexingEstimateApi(Resource): | |||
| notion_info_list = args["info_list"]["notion_info_list"] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info["workspace_id"] | |||
| credential_id = notion_info.get("credential_id") | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "credential_id": credential_id, | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page["page_id"], | |||
| "notion_page_type": page["type"], | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| import logging | |||
| from argparse import ArgumentTypeError | |||
| from typing import Literal, cast | |||
| @@ -51,6 +52,7 @@ from fields.document_fields import ( | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import login_required | |||
| from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile | |||
| from models.dataset import DocumentPipelineExecutionLog | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| @@ -496,6 +498,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "credential_id": data_source_info["credential_id"], | |||
| "notion_workspace_id": data_source_info["notion_workspace_id"], | |||
| "notion_obj_id": data_source_info["notion_page_id"], | |||
| "notion_page_type": data_source_info["type"], | |||
| @@ -649,7 +652,7 @@ class DocumentApi(DocumentResource): | |||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} | |||
| elif metadata == "without": | |||
| dataset_process_rules = DatasetService.get_process_rules(dataset_id) | |||
| document_process_rules = document.dataset_process_rule.to_dict() | |||
| document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} | |||
| data_source_info = document.data_source_detail_dict | |||
| response = { | |||
| "id": document.id, | |||
| @@ -1012,6 +1015,41 @@ class WebsiteDocumentSyncApi(DocumentResource): | |||
| return {"result": "success"}, 200 | |||
| class DocumentPipelineExecutionLogApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, dataset_id, document_id): | |||
| dataset_id = str(dataset_id) | |||
| document_id = str(document_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| log = ( | |||
| db.session.query(DocumentPipelineExecutionLog) | |||
| .filter_by(document_id=document_id) | |||
| .order_by(DocumentPipelineExecutionLog.created_at.desc()) | |||
| .first() | |||
| ) | |||
| if not log: | |||
| return { | |||
| "datasource_info": None, | |||
| "datasource_type": None, | |||
| "input_data": None, | |||
| "datasource_node_id": None, | |||
| }, 200 | |||
| return { | |||
| "datasource_info": json.loads(log.datasource_info), | |||
| "datasource_type": log.datasource_type, | |||
| "input_data": log.input_data, | |||
| "datasource_node_id": log.datasource_node_id, | |||
| }, 200 | |||
| api.add_resource(GetProcessRuleApi, "/datasets/process-rule") | |||
| api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents") | |||
| api.add_resource(DatasetInitApi, "/datasets/init") | |||
| @@ -1033,3 +1071,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") | |||
| api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename") | |||
| api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync") | |||
| api.add_resource( | |||
| DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log" | |||
| ) | |||
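A hedged example of calling the new endpoint; the base URL, IDs, and auth header are placeholders, and the response keys mirror the dict returned by DocumentPipelineExecutionLogApi above:

# Sketch only: fetch the latest pipeline execution log for a document.
import requests

CONSOLE_API = "https://example.com/console/api"  # placeholder base URL
dataset_id = "<dataset-uuid>"                    # placeholder
document_id = "<document-uuid>"                  # placeholder

resp = requests.get(
    f"{CONSOLE_API}/datasets/{dataset_id}/documents/{document_id}/pipeline-execution-log",
    headers={"Authorization": "Bearer <console-token>"},  # placeholder auth scheme
    timeout=10,
)
resp.raise_for_status()
log = resp.json()
print(log["datasource_type"], log["datasource_node_id"])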
| @@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException): | |||
| error_code = "child_chunk_delete_index_error" | |||
| description = "Delete child chunk index failed: {message}" | |||
| code = 500 | |||
| class PipelineNotFoundError(BaseHTTPException): | |||
| error_code = "pipeline_not_found" | |||
| description = "Pipeline not found." | |||
| code = 404 | |||
| @@ -0,0 +1,362 @@ | |||
| from fastapi.encoders import jsonable_encoder | |||
| from flask import make_response, redirect, request | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| from configs import dify_config | |||
| from controllers.console import api | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| setup_required, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from libs.helper import StrLen | |||
| from libs.login import login_required | |||
| from models.provider_ids import DatasourceProviderID | |||
| from services.datasource_provider_service import DatasourceProviderService | |||
| from services.plugin.oauth_service import OAuthProxyService | |||
| class DatasourcePluginOAuthAuthorizationUrl(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_id: str): | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| credential_id = request.args.get("credential_id") | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| provider_name = datasource_provider_id.provider_name | |||
| plugin_id = datasource_provider_id.plugin_id | |||
| oauth_config = DatasourceProviderService().get_oauth_client( | |||
| tenant_id=tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| ) | |||
| if not oauth_config: | |||
| raise ValueError(f"No OAuth Client Config for {provider_id}") | |||
| context_id = OAuthProxyService.create_proxy_context( | |||
| user_id=current_user.id, | |||
| tenant_id=tenant_id, | |||
| plugin_id=plugin_id, | |||
| provider=provider_name, | |||
| credential_id=credential_id, | |||
| ) | |||
| oauth_handler = OAuthHandler() | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" | |||
| authorization_url_response = oauth_handler.get_authorization_url( | |||
| tenant_id=tenant_id, | |||
| user_id=user.id, | |||
| plugin_id=plugin_id, | |||
| provider=provider_name, | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=oauth_config, | |||
| ) | |||
| response = make_response(jsonable_encoder(authorization_url_response)) | |||
| response.set_cookie( | |||
| "context_id", | |||
| context_id, | |||
| httponly=True, | |||
| samesite="Lax", | |||
| max_age=OAuthProxyService.__MAX_AGE__, | |||
| ) | |||
| return response | |||
| class DatasourceOAuthCallback(Resource): | |||
| @setup_required | |||
| def get(self, provider_id: str): | |||
| context_id = request.cookies.get("context_id") or request.args.get("context_id") | |||
| if not context_id: | |||
| raise Forbidden("context_id not found") | |||
| context = OAuthProxyService.use_proxy_context(context_id) | |||
| if context is None: | |||
| raise Forbidden("Invalid context_id") | |||
| user_id, tenant_id = context.get("user_id"), context.get("tenant_id") | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| plugin_id = datasource_provider_id.plugin_id | |||
| datasource_provider_service = DatasourceProviderService() | |||
| oauth_client_params = datasource_provider_service.get_oauth_client( | |||
| tenant_id=tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| ) | |||
| if not oauth_client_params: | |||
| raise NotFound() | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" | |||
| oauth_handler = OAuthHandler() | |||
| oauth_response = oauth_handler.get_credentials( | |||
| tenant_id=tenant_id, | |||
| user_id=user_id, | |||
| plugin_id=plugin_id, | |||
| provider=datasource_provider_id.provider_name, | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=oauth_client_params, | |||
| request=request, | |||
| ) | |||
| credential_id = context.get("credential_id") | |||
| if credential_id: | |||
| datasource_provider_service.reauthorize_datasource_oauth_provider( | |||
| tenant_id=tenant_id, | |||
| provider_id=datasource_provider_id, | |||
| avatar_url=oauth_response.metadata.get("avatar_url") or None, | |||
| name=oauth_response.metadata.get("name") or None, | |||
| expire_at=oauth_response.expires_at, | |||
| credentials=dict(oauth_response.credentials), | |||
| credential_id=context.get("credential_id"), | |||
| ) | |||
| else: | |||
| datasource_provider_service.add_datasource_oauth_provider( | |||
| tenant_id=tenant_id, | |||
| provider_id=datasource_provider_id, | |||
| avatar_url=oauth_response.metadata.get("avatar_url") or None, | |||
| name=oauth_response.metadata.get("name") or None, | |||
| expire_at=oauth_response.expires_at, | |||
| credentials=dict(oauth_response.credentials), | |||
| ) | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") | |||
| class DatasourceAuth(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| try: | |||
| datasource_provider_service.add_datasource_api_key_provider( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_id=datasource_provider_id, | |||
| credentials=args["credentials"], | |||
| name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_id: str): | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasources = datasource_provider_service.list_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=datasource_provider_id.provider_name, | |||
| plugin_id=datasource_provider_id.plugin_id, | |||
| ) | |||
| return {"result": datasources}, 200 | |||
| class DatasourceAuthDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| plugin_id = datasource_provider_id.plugin_id | |||
| provider_name = datasource_provider_id.provider_name | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.remove_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| auth_id=args["credential_id"], | |||
| provider=provider_name, | |||
| plugin_id=plugin_id, | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| class DatasourceAuthUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.update_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| auth_id=args["credential_id"], | |||
| provider=datasource_provider_id.provider_name, | |||
| plugin_id=datasource_provider_id.plugin_id, | |||
| credentials=args.get("credentials", {}), | |||
| name=args.get("name", None), | |||
| ) | |||
| return {"result": "success"}, 201 | |||
| class DatasourceAuthListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasources = datasource_provider_service.get_all_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| return {"result": jsonable_encoder(datasources)}, 200 | |||
| class DatasourceHardCodeAuthListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasources = datasource_provider_service.get_hard_code_datasource_credentials( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| return {"result": jsonable_encoder(datasources)}, 200 | |||
| class DatasourceAuthOauthCustomClient(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.setup_oauth_custom_client_params( | |||
| tenant_id=current_user.current_tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| client_params=args.get("client_params", {}), | |||
| enabled=args.get("enable_oauth_custom_client", False), | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider_id: str): | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.remove_oauth_custom_client_params( | |||
| tenant_id=current_user.current_tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| class DatasourceAuthDefaultApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.set_default_datasource_provider( | |||
| tenant_id=current_user.current_tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| credential_id=args["id"], | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| class DatasourceUpdateProviderNameApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_id: str): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| datasource_provider_id = DatasourceProviderID(provider_id) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| datasource_provider_service.update_datasource_provider_name( | |||
| tenant_id=current_user.current_tenant_id, | |||
| datasource_provider_id=datasource_provider_id, | |||
| name=args["name"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| api.add_resource( | |||
| DatasourcePluginOAuthAuthorizationUrl, | |||
| "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceOAuthCallback, | |||
| "/oauth/plugin/<path:provider_id>/datasource/callback", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuth, | |||
| "/auth/plugin/datasource/<path:provider_id>", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuthUpdateApi, | |||
| "/auth/plugin/datasource/<path:provider_id>/update", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuthDeleteApi, | |||
| "/auth/plugin/datasource/<path:provider_id>/delete", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuthListApi, | |||
| "/auth/plugin/datasource/list", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceHardCodeAuthListApi, | |||
| "/auth/plugin/datasource/default-list", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuthOauthCustomClient, | |||
| "/auth/plugin/datasource/<path:provider_id>/custom-client", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceAuthDefaultApi, | |||
| "/auth/plugin/datasource/<path:provider_id>/default", | |||
| ) | |||
| api.add_resource( | |||
| DatasourceUpdateProviderNameApi, | |||
| "/auth/plugin/datasource/<path:provider_id>/update-name", | |||
| ) | |||
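For orientation, a hedged example of registering an API-key credential against the new routes; the provider_id format is assumed from the Notion runtime id used elsewhere in this diff, the credential key matches the firecrawl migration above, and the URL and token are placeholders:

# Sketch only: add an API-key datasource credential through the console API.
import requests

CONSOLE_API = "https://example.com/console/api"                       # placeholder base URL
provider_id = "langgenius/firecrawl_datasource/firecrawl_datasource"  # format assumed

resp = requests.post(
    f"{CONSOLE_API}/auth/plugin/datasource/{provider_id}",
    headers={"Authorization": "Bearer <console-token>"},              # placeholder auth scheme
    json={"name": "My Firecrawl key", "credentials": {"firecrawl_api_key": "<api-key>"}},
    timeout=10,
)
print(resp.status_code, resp.json())  # expect 200 and {"result": "success"}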
| @@ -0,0 +1,57 @@ | |||
| from flask_restx import ( # type: ignore | |||
| Resource, # type: ignore | |||
| reqparse, | |||
| ) | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.datasets.wraps import get_rag_pipeline | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from libs.login import current_user, login_required | |||
| from models import Account | |||
| from models.dataset import Pipeline | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| class DataSourceContentPreviewApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def post(self, pipeline: Pipeline, node_id: str): | |||
| """ | |||
| Run datasource content preview | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("datasource_type", type=str, required=True, location="json") | |||
| parser.add_argument("credential_id", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| inputs = args.get("inputs") | |||
| if inputs is None: | |||
| raise ValueError("missing inputs") | |||
| datasource_type = args.get("datasource_type") | |||
| if datasource_type is None: | |||
| raise ValueError("missing datasource_type") | |||
| rag_pipeline_service = RagPipelineService() | |||
| preview_content = rag_pipeline_service.run_datasource_node_preview( | |||
| pipeline=pipeline, | |||
| node_id=node_id, | |||
| user_inputs=inputs, | |||
| account=current_user, | |||
| datasource_type=datasource_type, | |||
| is_published=True, | |||
| credential_id=args.get("credential_id"), | |||
| ) | |||
| return preview_content, 200 | |||
| api.add_resource( | |||
| DataSourceContentPreviewApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview", | |||
| ) | |||
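| # --- Example (not part of this PR): a hedged sketch of calling the preview route registered above. | |||
| # The base URL, the "/console/api" prefix, the bearer-token auth, and every id/value below are assumptions | |||
| # for illustration; the JSON fields mirror the RequestParser arguments (inputs, datasource_type, credential_id). | |||
| import requests | |||
|  | |||
| CONSOLE_API = "http://localhost:5001/console/api"  # assumed deployment URL and prefix | |||
| pipeline_id = "a3f1c2d4-0000-0000-0000-000000000000"  # hypothetical | |||
| node_id = "datasource_node_1"  # hypothetical | |||
|  | |||
| resp = requests.post( | |||
|     f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/workflows/published/datasource/nodes/{node_id}/preview", | |||
|     json={ | |||
|         "inputs": {},  # arbitrary datasource inputs; must be a dict | |||
|         "datasource_type": "online_document",  # hypothetical type value | |||
|         "credential_id": None,  # optional | |||
|     }, | |||
|     headers={"Authorization": "Bearer <console-token>"},  # assumed auth scheme | |||
| ) | |||
| print(resp.status_code, resp.json()) | |||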
| @@ -0,0 +1,164 @@ | |||
| import logging | |||
| from flask import request | |||
| from flask_restx import Resource, reqparse | |||
| from sqlalchemy.orm import Session | |||
| from controllers.console import api | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| enterprise_license_required, | |||
| knowledge_pipeline_publish_enabled, | |||
| setup_required, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from libs.login import login_required | |||
| from models.dataset import PipelineCustomizedTemplate | |||
| from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| logger = logging.getLogger(__name__) | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError("Name must be between 1 to 40 characters.") | |||
| return name | |||
| def _validate_description_length(description): | |||
| if len(description) > 400: | |||
| raise ValueError("Description cannot exceed 400 characters.") | |||
| return description | |||
| class PipelineTemplateListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def get(self): | |||
| type = request.args.get("type", default="built-in", type=str) | |||
| language = request.args.get("language", default="en-US", type=str) | |||
| # get pipeline templates | |||
| pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) | |||
| return pipeline_templates, 200 | |||
| class PipelineTemplateDetailApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def get(self, template_id: str): | |||
| type = request.args.get("type", default="built-in", type=str) | |||
| rag_pipeline_service = RagPipelineService() | |||
| pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) | |||
| return pipeline_template, 200 | |||
| class CustomizedPipelineTemplateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def patch(self, template_id: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument( | |||
| "description", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="", | |||
| ) | |||
| parser.add_argument( | |||
| "icon_info", | |||
| type=dict, | |||
| location="json", | |||
| nullable=True, | |||
| ) | |||
| args = parser.parse_args() | |||
| pipeline_template_info = PipelineTemplateInfoEntity(**args) | |||
| RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) | |||
| return 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def delete(self, template_id: str): | |||
| RagPipelineService.delete_customized_pipeline_template(template_id) | |||
| return 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def post(self, template_id: str): | |||
| with Session(db.engine) as session: | |||
| template = ( | |||
| session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() | |||
| ) | |||
| if not template: | |||
| raise ValueError("Customized pipeline template not found.") | |||
| return {"data": template.yaml_content}, 200 | |||
| class PublishCustomizedPipelineTemplateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| @knowledge_pipeline_publish_enabled | |||
| def post(self, pipeline_id: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument( | |||
| "description", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="", | |||
| ) | |||
| parser.add_argument( | |||
| "icon_info", | |||
| type=dict, | |||
| location="json", | |||
| nullable=True, | |||
| ) | |||
| args = parser.parse_args() | |||
| rag_pipeline_service = RagPipelineService() | |||
| rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) | |||
| return {"result": "success"} | |||
| api.add_resource( | |||
| PipelineTemplateListApi, | |||
| "/rag/pipeline/templates", | |||
| ) | |||
| api.add_resource( | |||
| PipelineTemplateDetailApi, | |||
| "/rag/pipeline/templates/<string:template_id>", | |||
| ) | |||
| api.add_resource( | |||
| CustomizedPipelineTemplateApi, | |||
| "/rag/pipeline/customized/templates/<string:template_id>", | |||
| ) | |||
| api.add_resource( | |||
| PublishCustomizedPipelineTemplateApi, | |||
| "/rag/pipelines/<string:pipeline_id>/customized/publish", | |||
| ) | |||
| @@ -0,0 +1,110 @@ | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restx import Resource, marshal, reqparse # type: ignore | |||
| from werkzeug.exceptions import Forbidden | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.datasets.error import DatasetNameDuplicateError | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_rate_limit_check, | |||
| setup_required, | |||
| ) | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from libs.login import login_required | |||
| from models.dataset import DatasetPermissionEnum | |||
| from services.dataset_service import DatasetPermissionService, DatasetService | |||
| from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity | |||
| from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError("Name must be between 1 to 40 characters.") | |||
| return name | |||
| def _validate_description_length(description): | |||
| if len(description) > 400: | |||
| raise ValueError("Description cannot exceed 400 characters.") | |||
| return description | |||
| class CreateRagPipelineDatasetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "yaml_content", | |||
| type=str, | |||
| nullable=False, | |||
| required=True, | |||
| help="yaml_content is required.", | |||
| ) | |||
| args = parser.parse_args() | |||
| # The current user's role in the tenant must be admin, owner, editor, or dataset_operator | |||
| if not current_user.is_dataset_editor: | |||
| raise Forbidden() | |||
| rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity( | |||
| name="", | |||
| description="", | |||
| icon_info=IconInfo( | |||
| icon="📙", | |||
| icon_background="#FFF4ED", | |||
| icon_type="emoji", | |||
| ), | |||
| permission=DatasetPermissionEnum.ONLY_ME, | |||
| partial_member_list=None, | |||
| yaml_content=args["yaml_content"], | |||
| ) | |||
| try: | |||
| import_info = RagPipelineDslService.create_rag_pipeline_dataset( | |||
| tenant_id=current_user.current_tenant_id, | |||
| rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, | |||
| ) | |||
| if rag_pipeline_dataset_create_entity.permission == "partial_members": | |||
| DatasetPermissionService.update_partial_member_list( | |||
| current_user.current_tenant_id, | |||
| import_info["dataset_id"], | |||
| rag_pipeline_dataset_create_entity.partial_member_list, | |||
| ) | |||
| except services.errors.dataset.DatasetNameDuplicateError: | |||
| raise DatasetNameDuplicateError() | |||
| return import_info, 201 | |||
| class CreateEmptyRagPipelineDatasetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self): | |||
| # The current user's role in the tenant must be admin, owner, editor, or dataset_operator | |||
| if not current_user.is_dataset_editor: | |||
| raise Forbidden() | |||
| dataset = DatasetService.create_empty_rag_pipeline_dataset( | |||
| tenant_id=current_user.current_tenant_id, | |||
| rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( | |||
| name="", | |||
| description="", | |||
| icon_info=IconInfo( | |||
| icon="📙", | |||
| icon_background="#FFF4ED", | |||
| icon_type="emoji", | |||
| ), | |||
| permission=DatasetPermissionEnum.ONLY_ME, | |||
| partial_member_list=None, | |||
| ), | |||
| ) | |||
| return marshal(dataset, dataset_detail_fields), 201 | |||
| api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset") | |||
| api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset") | |||
| @@ -0,0 +1,417 @@ | |||
| import logging | |||
| from typing import Any, NoReturn | |||
| from flask import Response | |||
| from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ( | |||
| DraftWorkflowNotExist, | |||
| ) | |||
| from controllers.console.datasets.wraps import get_rag_pipeline | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from controllers.web.error import InvalidArgumentError, NotFoundError | |||
| from core.variables.segment_group import SegmentGroup | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from extensions.ext_database import db | |||
| from factories.file_factory import build_from_mapping, build_from_mappings | |||
| from factories.variable_factory import build_segment_with_type | |||
| from libs.login import current_user, login_required | |||
| from models.account import Account | |||
| from models.dataset import Pipeline | |||
| from models.workflow import WorkflowDraftVariable | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService | |||
| logger = logging.getLogger(__name__) | |||
| def _convert_values_to_json_serializable_object(value: Segment) -> Any: | |||
| if isinstance(value, FileSegment): | |||
| return value.value.model_dump() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| return [i.model_dump() for i in value.value] | |||
| elif isinstance(value, SegmentGroup): | |||
| return [_convert_values_to_json_serializable_object(i) for i in value.value] | |||
| else: | |||
| return value.value | |||
| def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: | |||
| value = variable.get_value() | |||
| # create a copy of the value to avoid affecting the model cache. | |||
| value = value.model_copy(deep=True) | |||
| # Refresh the URL signature before returning it to the client. | |||
| if isinstance(value, FileSegment): | |||
| file = value.value | |||
| file.remote_url = file.generate_url() | |||
| elif isinstance(value, ArrayFileSegment): | |||
| files = value.value | |||
| for file in files: | |||
| file.remote_url = file.generate_url() | |||
| return _convert_values_to_json_serializable_object(value) | |||
| def _create_pagination_parser(): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "page", | |||
| type=inputs.int_range(1, 100_000), | |||
| required=False, | |||
| default=1, | |||
| location="args", | |||
| help="the page of data requested", | |||
| ) | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| return parser | |||
| _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { | |||
| "id": fields.String, | |||
| "type": fields.String(attribute=lambda model: model.get_variable_type()), | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( | |||
| _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, | |||
| value=fields.Raw(attribute=_serialize_var_value), | |||
| ) | |||
| _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { | |||
| "id": fields.String, | |||
| "type": fields.String(attribute=lambda _: "env"), | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), | |||
| "value_type": fields.String, | |||
| "edited": fields.Boolean(attribute=lambda model: model.edited), | |||
| "visible": fields.Boolean, | |||
| } | |||
| _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)), | |||
| } | |||
| def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: | |||
| return var_list.variables | |||
| _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), | |||
| "total": fields.Raw(), | |||
| } | |||
| _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { | |||
| "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), | |||
| } | |||
| def _api_prerequisite(f): | |||
| """Common prerequisites for all draft workflow variable APIs. | |||
| It ensures the following conditions are satisfied: | |||
| - Dify has been properly set up. | |||
| - The requesting user has logged in and completed account initialization. | |||
| - The requested RAG pipeline exists and belongs to the current tenant. | |||
| - The requesting user has edit permission. | |||
| """ | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def wrapper(*args, **kwargs): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| return f(*args, **kwargs) | |||
| return wrapper | |||
| class RagPipelineVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) | |||
| def get(self, pipeline: Pipeline): | |||
| """ | |||
| List draft workflow variables (without values) for the pipeline. | |||
| """ | |||
| parser = _create_pagination_parser() | |||
| args = parser.parse_args() | |||
| # ensure a draft workflow exists for this pipeline | |||
| rag_pipeline_service = RagPipelineService() | |||
| workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline) | |||
| if not workflow_exist: | |||
| raise DraftWorkflowNotExist() | |||
| # fetch draft variables for this pipeline | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| workflow_vars = draft_var_srv.list_variables_without_values( | |||
| app_id=pipeline.id, | |||
| page=args.page, | |||
| limit=args.limit, | |||
| ) | |||
| return workflow_vars | |||
| @_api_prerequisite | |||
| def delete(self, pipeline: Pipeline): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| draft_var_srv.delete_workflow_variables(pipeline.id) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| def validate_node_id(node_id: str) -> NoReturn | None: | |||
| if node_id in [ | |||
| CONVERSATION_VARIABLE_NODE_ID, | |||
| SYSTEM_VARIABLE_NODE_ID, | |||
| ]: | |||
| # NOTE(QuantumGhost): Although system and conversation variables are stored as node variables | |||
| # with specific `node_id` values in the database, we still want to keep the APIs separate. By disallowing | |||
| # access to system and conversation variables in `WorkflowDraftNodeVariableListApi`, | |||
| # we mitigate the risk of API users depending on this implementation detail. | |||
| # | |||
| # ref: [Hyrum's Law](https://www.hyrumslaw.com/) | |||
| raise InvalidArgumentError( | |||
| f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}", | |||
| ) | |||
| return None | |||
| class RagPipelineNodeVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) | |||
| def get(self, pipeline: Pipeline, node_id: str): | |||
| validate_node_id(node_id) | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id) | |||
| return node_vars | |||
| @_api_prerequisite | |||
| def delete(self, pipeline: Pipeline, node_id: str): | |||
| validate_node_id(node_id) | |||
| srv = WorkflowDraftVariableService(db.session()) | |||
| srv.delete_node_variables(pipeline.id, node_id) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| class RagPipelineVariableApi(Resource): | |||
| _PATCH_NAME_FIELD = "name" | |||
| _PATCH_VALUE_FIELD = "value" | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def get(self, pipeline: Pipeline, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != pipeline.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| return variable | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def patch(self, pipeline: Pipeline, variable_id: str): | |||
| # Request payload for file types: | |||
| # | |||
| # Local File: | |||
| # | |||
| # { | |||
| # "type": "image", | |||
| # "transfer_method": "local_file", | |||
| # "url": "", | |||
| # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190" | |||
| # } | |||
| # | |||
| # Remote File: | |||
| # | |||
| # | |||
| # { | |||
| # "type": "image", | |||
| # "transfer_method": "remote_url", | |||
| # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=", | |||
| # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" | |||
| # } | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") | |||
| # Parse 'value' field as-is to maintain its original data structure | |||
| parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| args = parser.parse_args(strict=True) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != pipeline.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| new_name = args.get(self._PATCH_NAME_FIELD, None) | |||
| raw_value = args.get(self._PATCH_VALUE_FIELD, None) | |||
| if new_name is None and raw_value is None: | |||
| return variable | |||
| new_value = None | |||
| if raw_value is not None: | |||
| if variable.value_type == SegmentType.FILE: | |||
| if not isinstance(raw_value, dict): | |||
| raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") | |||
| raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) | |||
| elif variable.value_type == SegmentType.ARRAY_FILE: | |||
| if not isinstance(raw_value, list): | |||
| raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") | |||
| if len(raw_value) > 0 and not isinstance(raw_value[0], dict): | |||
| raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") | |||
| raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) | |||
| new_value = build_segment_with_type(variable.value_type, raw_value) | |||
| draft_var_srv.update_variable(variable, name=new_name, value=new_value) | |||
| db.session.commit() | |||
| return variable | |||
| @_api_prerequisite | |||
| def delete(self, pipeline: Pipeline, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != pipeline.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| draft_var_srv.delete_variable(variable) | |||
| db.session.commit() | |||
| return Response("", 204) | |||
| class RagPipelineVariableResetApi(Resource): | |||
| @_api_prerequisite | |||
| def put(self, pipeline: Pipeline, variable_id: str): | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=db.session(), | |||
| ) | |||
| rag_pipeline_service = RagPipelineService() | |||
| draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) | |||
| if draft_workflow is None: | |||
| raise NotFoundError( | |||
| f"Draft workflow not found, pipeline_id={pipeline.id}", | |||
| ) | |||
| variable = draft_var_srv.get_variable(variable_id=variable_id) | |||
| if variable is None: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| if variable.app_id != pipeline.id: | |||
| raise NotFoundError(description=f"variable not found, id={variable_id}") | |||
| resetted = draft_var_srv.reset_variable(draft_workflow, variable) | |||
| db.session.commit() | |||
| if resetted is None: | |||
| return Response("", 204) | |||
| else: | |||
| return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) | |||
| def _get_variable_list(pipeline: Pipeline, node_id: str) -> WorkflowDraftVariableList: | |||
| with Session(bind=db.engine, expire_on_commit=False) as session: | |||
| draft_var_srv = WorkflowDraftVariableService( | |||
| session=session, | |||
| ) | |||
| if node_id == CONVERSATION_VARIABLE_NODE_ID: | |||
| draft_vars = draft_var_srv.list_conversation_variables(pipeline.id) | |||
| elif node_id == SYSTEM_VARIABLE_NODE_ID: | |||
| draft_vars = draft_var_srv.list_system_variables(pipeline.id) | |||
| else: | |||
| draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id) | |||
| return draft_vars | |||
| class RagPipelineSystemVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) | |||
| def get(self, pipeline: Pipeline): | |||
| return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) | |||
| class RagPipelineEnvironmentVariableCollectionApi(Resource): | |||
| @_api_prerequisite | |||
| def get(self, pipeline: Pipeline): | |||
| """ | |||
| List draft workflow environment variables. | |||
| """ | |||
| # fetch the draft workflow for this pipeline | |||
| rag_pipeline_service = RagPipelineService() | |||
| workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) | |||
| if workflow is None: | |||
| raise DraftWorkflowNotExist() | |||
| env_vars = workflow.environment_variables | |||
| env_vars_list = [] | |||
| for v in env_vars: | |||
| env_vars_list.append( | |||
| { | |||
| "id": v.id, | |||
| "type": "env", | |||
| "name": v.name, | |||
| "description": v.description, | |||
| "selector": v.selector, | |||
| "value_type": v.value_type.value, | |||
| "value": v.value, | |||
| # Do not track edited for env vars. | |||
| "edited": False, | |||
| "visible": True, | |||
| "editable": True, | |||
| } | |||
| ) | |||
| return {"items": env_vars_list} | |||
| api.add_resource( | |||
| RagPipelineVariableCollectionApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineNodeVariableCollectionApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>" | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset" | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables" | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineEnvironmentVariableCollectionApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables", | |||
| ) | |||
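| # --- Example (not part of this PR): a hedged sketch of PATCHing a FILE-typed draft variable via the | |||
| # route registered above. The payload shape follows the comment in RagPipelineVariableApi.patch; the | |||
| # base URL, auth header, and both ids are hypothetical. | |||
| import requests | |||
|  | |||
| CONSOLE_API = "http://localhost:5001/console/api"  # assumed prefix | |||
| pipeline_id = "7c9e6679-7425-40de-944b-e07fc1f90ae7"  # hypothetical | |||
| variable_id = "f47ac10b-58cc-4372-a567-0e02b2c3d479"  # hypothetical | |||
|  | |||
| resp = requests.patch( | |||
|     f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/workflows/draft/variables/{variable_id}", | |||
|     json={ | |||
|         # "value" is parsed as-is; for a FILE variable it must be a mapping in the | |||
|         # local_file / remote_url shape documented in the handler. | |||
|         "value": { | |||
|             "type": "image", | |||
|             "transfer_method": "local_file", | |||
|             "url": "", | |||
|             "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190", | |||
|         } | |||
|     }, | |||
|     headers={"Authorization": "Bearer <console-token>"},  # assumed auth scheme | |||
| ) | |||
| print(resp.status_code) | |||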
| @@ -0,0 +1,147 @@ | |||
| from typing import cast | |||
| from flask_login import current_user # type: ignore | |||
| from flask_restx import Resource, marshal_with, reqparse # type: ignore | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.datasets.wraps import get_rag_pipeline | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| setup_required, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields | |||
| from libs.login import login_required | |||
| from models import Account | |||
| from models.dataset import Pipeline | |||
| from services.app_dsl_service import ImportStatus | |||
| from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService | |||
| class RagPipelineImportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(pipeline_import_fields) | |||
| def post(self): | |||
| # Check user role first | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("mode", type=str, required=True, location="json") | |||
| parser.add_argument("yaml_content", type=str, location="json") | |||
| parser.add_argument("yaml_url", type=str, location="json") | |||
| parser.add_argument("name", type=str, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("icon_type", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| parser.add_argument("pipeline_id", type=str, location="json") | |||
| args = parser.parse_args() | |||
| # Create service with session | |||
| with Session(db.engine) as session: | |||
| import_service = RagPipelineDslService(session) | |||
| # Import the RAG pipeline | |||
| account = cast(Account, current_user) | |||
| result = import_service.import_rag_pipeline( | |||
| account=account, | |||
| import_mode=args["mode"], | |||
| yaml_content=args.get("yaml_content"), | |||
| yaml_url=args.get("yaml_url"), | |||
| pipeline_id=args.get("pipeline_id"), | |||
| dataset_name=args.get("name"), | |||
| ) | |||
| session.commit() | |||
| # Return appropriate status code based on result | |||
| status = result.status | |||
| if status == ImportStatus.FAILED.value: | |||
| return result.model_dump(mode="json"), 400 | |||
| elif status == ImportStatus.PENDING.value: | |||
| return result.model_dump(mode="json"), 202 | |||
| return result.model_dump(mode="json"), 200 | |||
| class RagPipelineImportConfirmApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(pipeline_import_fields) | |||
| def post(self, import_id): | |||
| # Check user role first | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # Create service with session | |||
| with Session(db.engine) as session: | |||
| import_service = RagPipelineDslService(session) | |||
| # Confirm import | |||
| account = cast(Account, current_user) | |||
| result = import_service.confirm_import(import_id=import_id, account=account) | |||
| session.commit() | |||
| # Return appropriate status code based on result | |||
| if result.status == ImportStatus.FAILED.value: | |||
| return result.model_dump(mode="json"), 400 | |||
| return result.model_dump(mode="json"), 200 | |||
| class RagPipelineImportCheckDependenciesApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @get_rag_pipeline | |||
| @account_initialization_required | |||
| @marshal_with(pipeline_import_check_dependencies_fields) | |||
| def get(self, pipeline: Pipeline): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| with Session(db.engine) as session: | |||
| import_service = RagPipelineDslService(session) | |||
| result = import_service.check_dependencies(pipeline=pipeline) | |||
| return result.model_dump(mode="json"), 200 | |||
| class RagPipelineExportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @get_rag_pipeline | |||
| @account_initialization_required | |||
| def get(self, pipeline: Pipeline): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # Parse the include_secret query param | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("include_secret", type=bool, default=False, location="args") | |||
| args = parser.parse_args() | |||
| with Session(db.engine) as session: | |||
| export_service = RagPipelineDslService(session) | |||
| result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"]) | |||
| return {"data": result}, 200 | |||
| # RAG pipeline import/export | |||
| api.add_resource( | |||
| RagPipelineImportApi, | |||
| "/rag/pipelines/imports", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineImportConfirmApi, | |||
| "/rag/pipelines/imports/<string:import_id>/confirm", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineImportCheckDependenciesApi, | |||
| "/rag/pipelines/imports/<string:pipeline_id>/check-dependencies", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineExportApi, | |||
| "/rag/pipelines/<string:pipeline_id>/exports", | |||
| ) | |||
| @@ -0,0 +1,47 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Optional | |||
| from controllers.console.datasets.error import PipelineNotFoundError | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.dataset import Pipeline | |||
| def get_rag_pipeline( | |||
| view: Optional[Callable] = None, | |||
| ): | |||
| def decorator(view_func): | |||
| @wraps(view_func) | |||
| def decorated_view(*args, **kwargs): | |||
| if not kwargs.get("pipeline_id"): | |||
| raise ValueError("missing pipeline_id in path parameters") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user is not an account") | |||
| pipeline_id = kwargs.get("pipeline_id") | |||
| pipeline_id = str(pipeline_id) | |||
| del kwargs["pipeline_id"] | |||
| pipeline = ( | |||
| db.session.query(Pipeline) | |||
| .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not pipeline: | |||
| raise PipelineNotFoundError() | |||
| kwargs["pipeline"] = pipeline | |||
| return view_func(*args, **kwargs) | |||
| return decorated_view | |||
| if view is None: | |||
| return decorator | |||
| else: | |||
| return decorator(view) | |||
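| # --- Example (not part of this PR): how get_rag_pipeline is meant to be applied. Because of the | |||
| # `view is None` branch it works both bare and called (@get_rag_pipeline / @get_rag_pipeline()); either | |||
| # way it pops "pipeline_id" from the view kwargs and injects the resolved Pipeline. The resource class | |||
| # and route below are hypothetical; the imports mirror the controllers earlier in this PR. | |||
| from flask_restx import Resource | |||
|  | |||
| from controllers.console import api | |||
| from controllers.console.datasets.wraps import get_rag_pipeline | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from libs.login import login_required | |||
| from models.dataset import Pipeline | |||
|  | |||
| class ExamplePipelineApi(Resource): | |||
|     @setup_required | |||
|     @login_required | |||
|     @account_initialization_required | |||
|     @get_rag_pipeline  # bare form; @get_rag_pipeline() behaves the same | |||
|     def get(self, pipeline: Pipeline): | |||
|         # the <uuid:pipeline_id> path parameter has been replaced by the Pipeline model | |||
|         return {"id": pipeline.id}, 200 | |||
|  | |||
| api.add_resource(ExamplePipelineApi, "/rag/pipelines/<uuid:pipeline_id>/example")  # hypothetical route | |||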
| @@ -20,6 +20,7 @@ from core.errors.error import ( | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.graph_engine.manager import GraphEngineManager | |||
| from libs import helper | |||
| from libs.login import current_user | |||
| from models.model import AppMode, InstalledApp | |||
| @@ -78,6 +79,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| raise NotWorkflowAppError() | |||
| assert current_user is not None | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| # Stop using both mechanisms for backward compatibility | |||
| # Legacy stop flag mechanism (without user check) | |||
| AppQueueManager.set_stop_flag_no_user_check(task_id) | |||
| # New graph engine command channel mechanism | |||
| GraphEngineManager.send_stop_command(task_id) | |||
| return {"result": "success"} | |||
| @@ -0,0 +1,35 @@ | |||
| import logging | |||
| from flask_restx import Resource | |||
| from controllers.console import api | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| setup_required, | |||
| ) | |||
| from core.schemas.schema_manager import SchemaManager | |||
| from libs.login import login_required | |||
| logger = logging.getLogger(__name__) | |||
| class SpecSchemaDefinitionsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| """ | |||
| Get system JSON Schema definitions specification | |||
| Used for frontend component type mapping | |||
| """ | |||
| try: | |||
| schema_manager = SchemaManager() | |||
| schema_definitions = schema_manager.get_all_schema_definitions() | |||
| return schema_definitions, 200 | |||
| except Exception: | |||
| logger.exception("Failed to get schema definitions from local registry") | |||
| # Return empty array as fallback | |||
| return [], 200 | |||
| api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") | |||
| @@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider | |||
| from core.mcp.error import MCPAuthError, MCPError | |||
| from core.mcp.mcp_client import MCPClient | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.entities.plugin import ToolProviderID | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from core.tools.entities.tool_entities import CredentialType | |||
| from libs.helper import StrLen, alphanumeric, uuid_value | |||
| from libs.login import login_required | |||
| from models.provider_ids import ToolProviderID | |||
| from services.plugin.oauth_service import OAuthProxyService | |||
| from services.tools.api_tools_manage_service import ApiToolManageService | |||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||
| @@ -261,3 +261,14 @@ def is_allow_transfer_owner(view): | |||
| abort(403) | |||
| return decorated | |||
| def knowledge_pipeline_publish_enabled(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.knowledge_pipeline.publish_enabled: | |||
| return view(*args, **kwargs) | |||
| abort(403) | |||
| return decorated | |||
| @@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError | |||
| from controllers.files import files_ns | |||
| from core.tools.signature import verify_tool_file_signature | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from models import db as global_db | |||
| from extensions.ext_database import db as global_db | |||
| @files_ns.route("/tools/<uuid:file_id>.<string:extension>") | |||
| @@ -26,7 +26,8 @@ from core.errors.error import ( | |||
| ) | |||
| from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from core.workflow.enums import WorkflowExecutionStatus | |||
| from core.workflow.graph_engine.manager import GraphEngineManager | |||
| from extensions.ext_database import db | |||
| from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model | |||
| from libs import helper | |||
| @@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource): | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| # Stop using both mechanisms for backward compatibility | |||
| # Legacy stop flag mechanism (without user check) | |||
| AppQueueManager.set_stop_flag_no_user_check(task_id) | |||
| # New graph engine command channel mechanism | |||
| GraphEngineManager.send_stop_command(task_id) | |||
| return {"result": "success"} | |||
| @@ -13,13 +13,13 @@ from controllers.service_api.wraps import ( | |||
| validate_dataset_token, | |||
| ) | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from fields.tag_fields import build_dataset_tag_fields | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DatasetPermissionEnum | |||
| from models.provider_ids import ModelProviderID | |||
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import RetrievalModel | |||
| from services.tag_service import TagService | |||
| @@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| # validate args | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| if not current_user: | |||
| raise ValueError("current_user is required") | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| @@ -21,6 +21,7 @@ from core.errors.error import ( | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.graph_engine.manager import GraphEngineManager | |||
| from libs import helper | |||
| from models.model import App, AppMode, EndUser | |||
| from services.app_generate_service import AppGenerateService | |||
| @@ -110,7 +111,12 @@ class WorkflowTaskStopApi(WebApiResource): | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| # Stop using both mechanisms for backward compatibility | |||
| # Legacy stop flag mechanism (without user check) | |||
| AppQueueManager.set_stop_flag_no_user_check(task_id) | |||
| # New graph engine command channel mechanism | |||
| GraphEngineManager.send_stop_command(task_id) | |||
| return {"result": "success"} | |||
| @@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner): | |||
| tenant_id=tenant_id, | |||
| dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], | |||
| retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, | |||
| return_resource=app_config.additional_features.show_retrieve_source, | |||
| return_resource=( | |||
| app_config.additional_features.show_retrieve_source if app_config.additional_features else False | |||
| ), | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| hit_callback=hit_callback, | |||
| user_id=user_id, | |||
| @@ -4,8 +4,8 @@ from typing import Any | |||
| from core.app.app_config.entities import ModelConfigEntity | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from models.provider_ids import ModelProviderID | |||
| class ModelConfigManager: | |||
| @@ -114,9 +114,9 @@ class VariableEntity(BaseModel): | |||
| hide: bool = False | |||
| max_length: Optional[int] = None | |||
| options: Sequence[str] = Field(default_factory=list) | |||
| allowed_file_types: Sequence[FileType] = Field(default_factory=list) | |||
| allowed_file_extensions: Sequence[str] = Field(default_factory=list) | |||
| allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) | |||
| allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list) | |||
| allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list) | |||
| allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list) | |||
| @field_validator("description", mode="before") | |||
| @classmethod | |||
| @@ -129,6 +129,16 @@ class VariableEntity(BaseModel): | |||
| return v or [] | |||
| class RagPipelineVariableEntity(VariableEntity): | |||
| """ | |||
| Rag Pipeline Variable Entity. | |||
| """ | |||
| tooltips: Optional[str] = None | |||
| placeholder: Optional[str] = None | |||
| belong_to_node_id: str | |||
| class ExternalDataVariableEntity(BaseModel): | |||
| """ | |||
| External Data Variable Entity. | |||
| @@ -288,7 +298,7 @@ class AppConfig(BaseModel): | |||
| tenant_id: str | |||
| app_id: str | |||
| app_mode: AppMode | |||
| additional_features: AppAdditionalFeatures | |||
| additional_features: Optional[AppAdditionalFeatures] = None | |||
| variables: list[VariableEntity] = [] | |||
| sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None | |||
| @@ -1,4 +1,6 @@ | |||
| from core.app.app_config.entities import VariableEntity | |||
| import re | |||
| from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity | |||
| from models.workflow import Workflow | |||
| @@ -20,3 +22,44 @@ class WorkflowVariablesConfigManager: | |||
| variables.append(VariableEntity.model_validate(variable)) | |||
| return variables | |||
| @classmethod | |||
| def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]: | |||
| """ | |||
| Convert RAG pipeline variables into variable entities for the given start node. | |||
| :param workflow: workflow instance | |||
| :param start_node_id: id of the datasource (start) node | |||
| """ | |||
| variables = [] | |||
| # collect the second-step (pipeline-level) variable definitions | |||
| rag_pipeline_variables = workflow.rag_pipeline_variables | |||
| if not rag_pipeline_variables: | |||
| return [] | |||
| variables_map = {item["variable"]: item for item in rag_pipeline_variables} | |||
| # get datasource node data | |||
| datasource_node_data = None | |||
| datasource_nodes = workflow.graph_dict.get("nodes", []) | |||
| for datasource_node in datasource_nodes: | |||
| if datasource_node.get("id") == start_node_id: | |||
| datasource_node_data = datasource_node.get("data", {}) | |||
| break | |||
| if datasource_node_data: | |||
| datasource_parameters = datasource_node_data.get("datasource_parameters", {}) | |||
| for key, value in datasource_parameters.items(): | |||
| if value.get("value") and isinstance(value.get("value"), str): | |||
| pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" | |||
| match = re.match(pattern, value["value"]) | |||
| if match: | |||
| full_path = match.group(1) | |||
| last_part = full_path.split(".")[-1] | |||
| variables_map.pop(last_part, None)  # tolerate references to variables not defined here | |||
| all_second_step_variables = list(variables_map.values()) | |||
| for item in all_second_step_variables: | |||
| if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared": | |||
| variables.append(RagPipelineVariableEntity.model_validate(item)) | |||
| return variables | |||
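| # --- Example (not part of this PR): what the reference pattern above matches. The pattern is copied | |||
| # verbatim from convert_rag_pipeline_variable; the sample value is hypothetical. | |||
| import re | |||
|  | |||
| pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" | |||
| value = "{{#shared.document_url#}}"  # hypothetical datasource parameter value | |||
| match = re.match(pattern, value) | |||
| if match: | |||
|     full_path = match.group(1)            # "shared.document_url" | |||
|     last_part = full_path.split(".")[-1]  # "document_url" -> removed from variables_map | |||
|     print(full_path, last_part) | |||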
| @@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| # always enable retriever resource in debugger mode | |||
| app_config.additional_features.show_retrieve_source = True | |||
| app_config.additional_features.show_retrieve_source = True # type: ignore | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| # init application generate entity | |||
| @@ -1,11 +1,11 @@ | |||
| import logging | |||
| import time | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| @@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF | |||
| from core.moderation.base import ModerationError | |||
| from core.moderation.input_moderation import InputModeration | |||
| from core.variables.variables import VariableUnion | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities import GraphRuntimeState, VariablePool | |||
| from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models import Workflow | |||
| from models.enums import UserFrom | |||
| from models.model import App, Conversation, Message, MessageAnnotation | |||
| from models.workflow import ConversationVariable, WorkflowType | |||
| from models.workflow import ConversationVariable | |||
| logger = logging.getLogger(__name__) | |||
| @@ -76,23 +77,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| # if only single loop run is requested | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| @@ -144,16 +151,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| ) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=self._workflow.graph_dict) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) | |||
| graph = self._init_graph( | |||
| graph_config=self._workflow.graph_dict, | |||
| graph_runtime_state=graph_runtime_state, | |||
| workflow_id=self._workflow.id, | |||
| tenant_id=self._workflow.tenant_id, | |||
| user_id=self.application_generate_entity.user_id, | |||
| ) | |||
| db.session.close() | |||
| # RUN WORKFLOW | |||
| # Create Redis command channel for this workflow execution | |||
| task_id = self.application_generate_entity.task_id | |||
| channel_key = f"workflow:{task_id}:commands" | |||
| command_channel = RedisChannel(redis_client, channel_key) | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=self._workflow.tenant_id, | |||
| app_id=self._workflow.app_id, | |||
| workflow_id=self._workflow.id, | |||
| workflow_type=WorkflowType.value_of(self._workflow.type), | |||
| graph=graph, | |||
| graph_config=self._workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| @@ -164,12 +182,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| graph_runtime_state=graph_runtime_state, | |||
| command_channel=command_channel, | |||
| ) | |||
| generator = workflow_entry.run( | |||
| callbacks=workflow_callbacks, | |||
| ) | |||
| generator = workflow_entry.run() | |||
| for event in generator: | |||
| self._handle_event(workflow_entry, event) | |||
| @@ -31,14 +31,9 @@ from core.app.entities.queue_entities import ( | |||
| QueueMessageReplaceEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueuePingEvent, | |||
| QueueRetrieverResourcesEvent, | |||
| QueueStopEvent, | |||
| @@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities import GraphRuntimeState | |||
| from core.workflow.enums import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| @@ -395,9 +390,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| def _handle_node_failed_events( | |||
| self, | |||
| event: Union[ | |||
| QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent | |||
| ], | |||
| event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent], | |||
| **kwargs, | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle various node failure events.""" | |||
| @@ -442,32 +435,6 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| def _handle_parallel_branch_started_event( | |||
| self, event: QueueParallelBranchRunStartedEvent, **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle parallel branch started events.""" | |||
| self._ensure_workflow_initialized() | |||
| parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_start_resp | |||
| def _handle_parallel_branch_finished_events( | |||
| self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle parallel branch finished events.""" | |||
| self._ensure_workflow_initialized() | |||
| parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_finish_resp | |||
| def _handle_iteration_start_event( | |||
| self, event: QueueIterationStartEvent, **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| @@ -759,8 +726,6 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| QueueNodeRetryEvent: self._handle_node_retry_event, | |||
| QueueNodeStartedEvent: self._handle_node_started_event, | |||
| QueueNodeSucceededEvent: self._handle_node_succeeded_event, | |||
| # Parallel branch events | |||
| QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, | |||
| # Iteration events | |||
| QueueIterationStartEvent: self._handle_iteration_start_event, | |||
| QueueIterationNextEvent: self._handle_iteration_next_event, | |||
| @@ -808,8 +773,6 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| event, | |||
| ( | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeExceptionEvent, | |||
| ), | |||
| ): | |||
| @@ -822,17 +785,6 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| return | |||
| # Handle parallel branch finished events with isinstance check | |||
| if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): | |||
| yield from self._handle_parallel_branch_finished_events( | |||
| event, | |||
| graph_runtime_state=graph_runtime_state, | |||
| tts_publisher=tts_publisher, | |||
| trace_manager=trace_manager, | |||
| queue_message=queue_message, | |||
| ) | |||
| return | |||
| # For unhandled events, we continue (original behavior) | |||
| return | |||
| @@ -856,11 +808,6 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| graph_runtime_state = event.graph_runtime_state | |||
| yield from self._handle_workflow_started_event(event) | |||
| case QueueTextChunkEvent(): | |||
| yield from self._handle_text_chunk_event( | |||
| event, tts_publisher=tts_publisher, queue_message=queue_message | |||
| ) | |||
| case QueueErrorEvent(): | |||
| yield from self._handle_error_event(event) | |||
| break | |||
| @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session | |||
| from core.app.app_config.entities import VariableEntityType | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file import File, FileUploadConfig | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.enums import NodeType | |||
| from core.workflow.repositories.draft_variable_repository import ( | |||
| DraftVariableSaver, | |||
| DraftVariableSaverFactory, | |||
| @@ -126,6 +126,21 @@ class AppQueueManager: | |||
| stopped_cache_key = cls._generate_stopped_cache_key(task_id) | |||
| redis_client.setex(stopped_cache_key, 600, 1) | |||
| @classmethod | |||
| def set_stop_flag_no_user_check(cls, task_id: str) -> None: | |||
| """ | |||
| Set task stop flag without user permission check. | |||
| This method allows stopping workflows without user context. | |||
| :param task_id: The task ID to stop | |||
| :return: | |||
| """ | |||
| if not task_id: | |||
| return | |||
| stopped_cache_key = cls._generate_stopped_cache_key(task_id) | |||
| redis_client.setex(stopped_cache_key, 600, 1) | |||
| def _is_stopped(self) -> bool: | |||
| """ | |||
| Check if task is stopped | |||
| @@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner): | |||
| config=app_config.dataset, | |||
| query=query, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| show_retrieve_source=app_config.additional_features.show_retrieve_source, | |||
| show_retrieve_source=( | |||
| app_config.additional_features.show_retrieve_source if app_config.additional_features else False | |||
| ), | |||
| hit_callback=hit_callback, | |||
| memory=memory, | |||
| message_id=message.id, | |||
| @@ -1,7 +1,7 @@ | |||
| import time | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import Any, Optional, Union, cast | |||
| from typing import Any, Optional, Union | |||
| from sqlalchemy.orm import Session | |||
| @@ -16,14 +16,9 @@ from core.app.entities.queue_entities import ( | |||
| QueueLoopStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AgentLogStreamResponse, | |||
| @@ -36,18 +31,17 @@ from core.app.entities.task_entities import ( | |||
| NodeFinishStreamResponse, | |||
| NodeRetryStreamResponse, | |||
| NodeStartStreamResponse, | |||
| ParallelBranchFinishedStreamResponse, | |||
| ParallelBranchStartStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| ) | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus | |||
| from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution | |||
| from core.workflow.enums import WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter | |||
| from libs.datetime_utils import naive_utc_now | |||
| from models import ( | |||
| @@ -174,23 +168,25 @@ class WorkflowResponseConverter: | |||
| # extras logic | |||
| if event.node_type == NodeType.TOOL: | |||
| node_data = cast(ToolNodeData, event.node_data) | |||
| response.data.extras["icon"] = ToolManager.get_tool_icon( | |||
| tenant_id=self._application_generate_entity.app_config.tenant_id, | |||
| provider_type=node_data.provider_type, | |||
| provider_id=node_data.provider_id, | |||
| provider_type=ToolProviderType(event.provider_type), | |||
| provider_id=event.provider_id, | |||
| ) | |||
| elif event.node_type == NodeType.DATASOURCE: | |||
| manager = PluginDatasourceManager() | |||
| provider_entity = manager.fetch_datasource_provider( | |||
| self._application_generate_entity.app_config.tenant_id, | |||
| event.provider_id, | |||
| ) | |||
| response.data.extras["icon"] = provider_entity.declaration.identity.icon | |||
| return response | |||
| def workflow_node_finish_to_stream_response( | |||
| self, | |||
| *, | |||
| event: QueueNodeSucceededEvent | |||
| | QueueNodeFailedEvent | |||
| | QueueNodeInIterationFailedEvent | |||
| | QueueNodeInLoopFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| @@ -227,9 +223,6 @@ class WorkflowResponseConverter: | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| loop_id=event.in_loop_id, | |||
| ), | |||
| @@ -284,50 +277,6 @@ class WorkflowResponseConverter: | |||
| ), | |||
| ) | |||
| def workflow_parallel_branch_start_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueParallelBranchRunStartedEvent, | |||
| ) -> ParallelBranchStartStreamResponse: | |||
| return ParallelBranchStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=ParallelBranchStartStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| loop_id=event.in_loop_id, | |||
| created_at=int(time.time()), | |||
| ), | |||
| ) | |||
| def workflow_parallel_branch_finished_to_stream_response( | |||
| self, | |||
| *, | |||
| task_id: str, | |||
| workflow_execution_id: str, | |||
| event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, | |||
| ) -> ParallelBranchFinishedStreamResponse: | |||
| return ParallelBranchFinishedStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_execution_id, | |||
| data=ParallelBranchFinishedStreamResponse.Data( | |||
| parallel_id=event.parallel_id, | |||
| parallel_branch_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| loop_id=event.in_loop_id, | |||
| status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", | |||
| error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, | |||
| created_at=int(time.time()), | |||
| ), | |||
| ) | |||
| def workflow_iteration_start_to_stream_response( | |||
| self, | |||
| *, | |||
| @@ -343,14 +292,12 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=new_inputs, | |||
| inputs_truncated=truncated, | |||
| metadata=event.metadata or {}, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ), | |||
| ) | |||
| @@ -368,17 +315,10 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| index=event.index, | |||
| # The `pre_iteration_output` field is not utilized by the frontend. | |||
| # Previously, it was assigned the value of `event.output`. | |||
| pre_iteration_output={}, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| duration=event.duration, | |||
| ), | |||
| ) | |||
| @@ -402,7 +342,7 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| outputs=new_outputs, | |||
| outputs_truncated=outputs_truncated, | |||
| created_at=int(time.time()), | |||
| @@ -418,8 +358,6 @@ class WorkflowResponseConverter: | |||
| execution_metadata=event.metadata, | |||
| finished_at=int(time.time()), | |||
| steps=event.steps, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| ), | |||
| ) | |||
| @@ -434,7 +372,7 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=new_inputs, | |||
| @@ -459,7 +397,7 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| index=event.index, | |||
| # The `pre_loop_output` field is not utilized by the frontend. | |||
| # Previously, it was assigned the value of `event.output`. | |||
| @@ -469,7 +407,6 @@ class WorkflowResponseConverter: | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| duration=event.duration, | |||
| ), | |||
| ) | |||
| @@ -492,7 +429,7 @@ class WorkflowResponseConverter: | |||
| id=event.node_id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type.value, | |||
| title=event.node_data.title, | |||
| title=event.node_title, | |||
| outputs=new_outputs, | |||
| outputs_truncated=outputs_truncated, | |||
| created_at=int(time.time()), | |||
| @@ -0,0 +1,95 @@ | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ErrorStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeStartStreamResponse, | |||
| PingStreamResponse, | |||
| WorkflowAppBlockingResponse, | |||
| WorkflowAppStreamResponse, | |||
| ) | |||
| class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = WorkflowAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| :return: | |||
| """ | |||
| return dict(blocking_response.to_dict()) | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| :return: | |||
| """ | |||
| return cls.convert_blocking_full_response(blocking_response) | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| :return: | |||
| """ | |||
| for chunk in stream_response: | |||
| chunk = cast(WorkflowAppStreamResponse, chunk) | |||
| sub_stream_response = chunk.stream_response | |||
| if isinstance(sub_stream_response, PingStreamResponse): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| if isinstance(sub_stream_response, ErrorStreamResponse): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| :return: | |||
| """ | |||
| for chunk in stream_response: | |||
| chunk = cast(WorkflowAppStreamResponse, chunk) | |||
| sub_stream_response = chunk.stream_response | |||
| if isinstance(sub_stream_response, PingStreamResponse): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| if isinstance(sub_stream_response, ErrorStreamResponse): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield response_chunk | |||
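Both converters in this new module follow the same shape: wrap each chunk into a dict keyed by `event` and `workflow_run_id`, pass `"ping"` through as a bare keepalive, expand errors, and (in the simple variant) strip node-level detail. A simplified, self-contained sketch of that shape, using a small dataclass instead of Dify's stream-response entities:

```python
from collections.abc import Generator
from dataclasses import dataclass, field

@dataclass
class StreamChunk:
    event: str                     # e.g. "ping", "error", "node_started"
    workflow_run_id: str = ""
    payload: dict = field(default_factory=dict)

def convert_stream(chunks) -> Generator[dict | str, None, None]:
    """Convert raw stream chunks into SSE-friendly dicts, yielding "ping" for keepalives."""
    for chunk in chunks:
        if chunk.event == "ping":
            yield "ping"
            continue
        response = {"event": chunk.event, "workflow_run_id": chunk.workflow_run_id}
        response.update(chunk.payload)
        yield response

stream = [
    StreamChunk(event="ping"),
    StreamChunk(event="node_started", workflow_run_id="run-1", payload={"node_id": "n1"}),
]
for item in convert_stream(stream):
    print(item)
```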
| @@ -0,0 +1,66 @@ | |||
| from core.app.app_config.base_app_config_manager import BaseAppConfigManager | |||
| from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager | |||
| from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager | |||
| from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager | |||
| from models.dataset import Pipeline | |||
| from models.model import AppMode | |||
| from models.workflow import Workflow | |||
| class PipelineConfig(WorkflowUIBasedAppConfig): | |||
| """ | |||
| Pipeline Config Entity. | |||
| """ | |||
| rag_pipeline_variables: list[RagPipelineVariableEntity] = [] | |||

| class PipelineConfigManager(BaseAppConfigManager): | |||
| @classmethod | |||
| def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig: | |||
| pipeline_config = PipelineConfig( | |||
| tenant_id=pipeline.tenant_id, | |||
| app_id=pipeline.id, | |||
| app_mode=AppMode.RAG_PIPELINE, | |||
| workflow_id=workflow.id, | |||
| rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable( | |||
| workflow=workflow, start_node_id=start_node_id | |||
| ), | |||
| ) | |||
| return pipeline_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| """ | |||
| Validate for pipeline config | |||
| :param tenant_id: tenant id | |||
| :param config: app model config args | |||
| :param only_structure_validate: only validate the structure of the config | |||
| """ | |||
| related_config_keys = [] | |||
| # file upload validation | |||
| config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| # text_to_speech | |||
| config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| # moderation validation | |||
| config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( | |||
| tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate | |||
| ) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| related_config_keys = list(set(related_config_keys)) | |||
| # Filter out extra parameters | |||
| filtered_config = {key: config.get(key) for key in related_config_keys} | |||
| return filtered_config | |||
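`config_validate` runs each feature validator in turn, collects the config keys the validators claim, and returns only those keys, silently discarding anything else the client sent. A minimal sketch of that validate-and-filter flow with stand-in validators (the real ones are `FileUploadConfigManager`, `TextToSpeechConfigManager`, and `SensitiveWordAvoidanceConfigManager`):

```python
def validate_file_upload(config: dict) -> tuple[dict, list[str]]:
    # Stand-in validator: set a default and report which keys it owns.
    config.setdefault("file_upload", {"enabled": False})
    return config, ["file_upload"]

def validate_text_to_speech(config: dict) -> tuple[dict, list[str]]:
    config.setdefault("text_to_speech", {"enabled": False})
    return config, ["text_to_speech"]

def config_validate(config: dict) -> dict:
    related_keys: list[str] = []
    for validator in (validate_file_upload, validate_text_to_speech):
        config, keys = validator(config)
        related_keys.extend(keys)
    # Keep only the keys owned by the validators; everything else is dropped.
    return {key: config.get(key) for key in set(related_keys)}

print(config_validate({"text_to_speech": {"enabled": True}, "unrelated": 1}))
# 'unrelated' is dropped; only the validator-owned keys survive.
```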
| @@ -0,0 +1,802 @@ | |||
| import contextvars | |||
| import datetime | |||
| import json | |||
| import logging | |||
| import secrets | |||
| import threading | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Optional, Union, cast, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| import contexts | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_generator import BaseAppGenerator | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.exc import GenerateTaskStoppedError | |||
| from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager | |||
| from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager | |||
| from core.app.apps.pipeline.pipeline_runner import PipelineRunner | |||
| from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter | |||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceProviderType, | |||
| OnlineDriveBrowseFilesRequest, | |||
| ) | |||
| from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin | |||
| from core.entities.knowledge_entities import PipelineDataset, PipelineDocument | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader | |||
| from extensions.ext_database import db | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline | |||
| from models.enums import WorkflowRunTriggeredFrom | |||
| from models.model import AppMode | |||
| from services.dataset_service import DocumentService | |||
| from services.datasource_provider_service import DatasourceProviderService | |||
| from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService | |||
| from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task | |||
| logger = logging.getLogger(__name__) | |||
| class PipelineGenerator(BaseAppGenerator): | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ... | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: | |||
| # Add null check for dataset | |||
| dataset = pipeline.dataset | |||
| if not dataset: | |||
| raise ValueError("Pipeline dataset is required") | |||
| inputs: Mapping[str, Any] = args["inputs"] | |||
| start_node_id: str = args["start_node_id"] | |||
| datasource_type: str = args["datasource_type"] | |||
| datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( | |||
| datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user | |||
| ) | |||
| batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) | |||
| # convert to app config | |||
| pipeline_config = PipelineConfigManager.get_pipeline_config( | |||
| pipeline=pipeline, workflow=workflow, start_node_id=start_node_id | |||
| ) | |||
| documents = [] | |||
| if invoke_from == InvokeFrom.PUBLISHED: | |||
| for datasource_info in datasource_info_list: | |||
| position = DocumentService.get_documents_position(dataset.id) | |||
| document = self._build_document( | |||
| tenant_id=pipeline.tenant_id, | |||
| dataset_id=dataset.id, | |||
| built_in_field_enabled=dataset.built_in_field_enabled, | |||
| datasource_type=datasource_type, | |||
| datasource_info=datasource_info, | |||
| created_from="rag-pipeline", | |||
| position=position, | |||
| account=user, | |||
| batch=batch, | |||
| document_form=dataset.chunk_structure, | |||
| ) | |||
| db.session.add(document) | |||
| documents.append(document) | |||
| db.session.commit() | |||
| # run in child thread | |||
| for i, datasource_info in enumerate(datasource_info_list): | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| document_id = None | |||
| if invoke_from == InvokeFrom.PUBLISHED: | |||
| document_id = documents[i].id | |||
| document_pipeline_execution_log = DocumentPipelineExecutionLog( | |||
| document_id=document_id, | |||
| datasource_type=datasource_type, | |||
| datasource_info=json.dumps(datasource_info), | |||
| datasource_node_id=start_node_id, | |||
| input_data=inputs, | |||
| pipeline_id=pipeline.id, | |||
| created_by=user.id, | |||
| ) | |||
| db.session.add(document_pipeline_execution_log) | |||
| db.session.commit() | |||
| application_generate_entity = RagPipelineGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=pipeline_config, | |||
| pipeline_config=pipeline_config, | |||
| datasource_type=datasource_type, | |||
| datasource_info=datasource_info, | |||
| dataset_id=dataset.id, | |||
| start_node_id=start_node_id, | |||
| batch=batch, | |||
| document_id=document_id, | |||
| inputs=self._prepare_user_inputs( | |||
| user_inputs=inputs, | |||
| variables=pipeline_config.rag_pipeline_variables, | |||
| tenant_id=pipeline.tenant_id, | |||
| strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, | |||
| ), | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| workflow_execution_id=workflow_run_id, | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING | |||
| else: | |||
| workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=workflow_triggered_from, | |||
| ) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, | |||
| ) | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| return self._generate( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| context=contextvars.copy_context(), | |||
| pipeline=pipeline, | |||
| workflow_id=workflow.id, | |||
| user=user, | |||
| application_generate_entity=application_generate_entity, | |||
| invoke_from=invoke_from, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| else: | |||
| rag_pipeline_run_task.delay( # type: ignore | |||
| pipeline_id=pipeline.id, | |||
| user_id=user.id, | |||
| tenant_id=pipeline.tenant_id, | |||
| workflow_id=workflow.id, | |||
| streaming=streaming, | |||
| workflow_execution_id=workflow_run_id, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| application_generate_entity=application_generate_entity.model_dump(), | |||
| ) | |||
| # return batch, dataset, documents | |||
| return { | |||
| "batch": batch, | |||
| "dataset": PipelineDataset( | |||
| id=dataset.id, | |||
| name=dataset.name, | |||
| description=dataset.description, | |||
| chunk_structure=dataset.chunk_structure, | |||
| ).model_dump(), | |||
| "documents": [ | |||
| PipelineDocument( | |||
| id=document.id, | |||
| position=document.position, | |||
| data_source_type=document.data_source_type, | |||
| data_source_info=json.loads(document.data_source_info) if document.data_source_info else None, | |||
| name=document.name, | |||
| indexing_status=document.indexing_status, | |||
| error=document.error, | |||
| enabled=document.enabled, | |||
| ).model_dump() | |||
| for document in documents | |||
| ], | |||
| } | |||
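For published (non-debugger) runs the generator enqueues `rag_pipeline_run_task` and returns immediately with the batch id plus dataset and document summaries, rather than streaming. A hedged sketch of the mapping a caller receives, with illustrative values only (the field names are the ones serialized above):

```python
# Illustrative only: the shape of the mapping returned for published (non-debugger) runs.
result = {
    "batch": "20240101120000123456",  # strftime timestamp + 6 random digits
    "dataset": {"id": "ds-1", "name": "My dataset", "description": "", "chunk_structure": "text_model"},
    "documents": [
        {"id": "doc-1", "position": 1, "data_source_type": "local_file",
         "data_source_info": {"name": "a.pdf"}, "name": "a.pdf",
         "indexing_status": "waiting", "error": None, "enabled": True},
    ],
}

# A caller can poll indexing progress per document id.
for document in result["documents"]:
    print(document["id"], document["indexing_status"])
```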
| def _generate( | |||
| self, | |||
| *, | |||
| flask_app: Flask, | |||
| context: contextvars.Context, | |||
| pipeline: Pipeline, | |||
| workflow_id: str, | |||
| user: Union[Account, EndUser], | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| invoke_from: InvokeFrom, | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| streaming: bool = True, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate App response. | |||
| :param pipeline: Pipeline | |||
| :param workflow_id: Workflow ID | |||
| :param user: account or end user | |||
| :param application_generate_entity: application generate entity | |||
| :param invoke_from: invoke from source | |||
| :param workflow_execution_repository: repository for workflow execution | |||
| :param workflow_node_execution_repository: repository for workflow node execution | |||
| :param streaming: is stream | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||
| # init queue manager | |||
| workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() | |||
| if not workflow: | |||
| raise ValueError(f"Workflow not found: {workflow_id}") | |||
| queue_manager = PipelineQueueManager( | |||
| task_id=application_generate_entity.task_id, | |||
| user_id=application_generate_entity.user_id, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| app_mode=AppMode.RAG_PIPELINE, | |||
| ) | |||
| context = contextvars.copy_context() | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "context": context, | |||
| "queue_manager": queue_manager, | |||
| "application_generate_entity": application_generate_entity, | |||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||
| "variable_loader": variable_loader, | |||
| }, | |||
| ) | |||
| worker_thread.start() | |||
| draft_var_saver_factory = self._get_draft_var_saver_factory( | |||
| invoke_from, | |||
| ) | |||
| # return response or stream generator | |||
| response = self._handle_response( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=streaming, | |||
| draft_var_saver_factory=draft_var_saver_factory, | |||
| ) | |||
| return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) | |||
| def single_iteration_generate( | |||
| self, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| :param pipeline: Pipeline | |||
| :param workflow: Workflow | |||
| :param node_id: the node id | |||
| :param user: account or end user | |||
| :param args: request args | |||
| :param streaming: is streamed | |||
| """ | |||
| if not node_id: | |||
| raise ValueError("node_id is required") | |||
| if args.get("inputs") is None: | |||
| raise ValueError("inputs is required") | |||
| # convert to app config | |||
| pipeline_config = PipelineConfigManager.get_pipeline_config( | |||
| pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared") | |||
| ) | |||
| dataset = pipeline.dataset | |||
| if not dataset: | |||
| raise ValueError("Pipeline dataset is required") | |||
| # init application generate entity - use RagPipelineGenerateEntity instead | |||
| application_generate_entity = RagPipelineGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=pipeline_config, | |||
| pipeline_config=pipeline_config, | |||
| datasource_type=args.get("datasource_type", ""), | |||
| datasource_info=args.get("datasource_info", {}), | |||
| dataset_id=dataset.id, | |||
| batch=args.get("batch", ""), | |||
| document_id=args.get("document_id"), | |||
| inputs={}, | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| workflow_execution_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, | |||
| ) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| return self._generate( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| pipeline=pipeline, | |||
| workflow_id=workflow.id, | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def single_loop_generate( | |||
| self, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| :param pipeline: Pipeline | |||
| :param workflow: Workflow | |||
| :param node_id: the node id | |||
| :param user: account or end user | |||
| :param args: request args | |||
| :param streaming: is streamed | |||
| """ | |||
| if not node_id: | |||
| raise ValueError("node_id is required") | |||
| if args.get("inputs") is None: | |||
| raise ValueError("inputs is required") | |||
| dataset = pipeline.dataset | |||
| if not dataset: | |||
| raise ValueError("Pipeline dataset is required") | |||
| # convert to app config | |||
| pipeline_config = PipelineConfigManager.get_pipeline_config( | |||
| pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared") | |||
| ) | |||
| # init application generate entity | |||
| application_generate_entity = RagPipelineGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=pipeline_config, | |||
| pipeline_config=pipeline_config, | |||
| datasource_type=args.get("datasource_type", ""), | |||
| datasource_info=args.get("datasource_info", {}), | |||
| batch=args.get("batch", ""), | |||
| document_id=args.get("document_id"), | |||
| dataset_id=dataset.id, | |||
| inputs={}, | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras={"auto_generate_conversation_name": False}, | |||
| single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), | |||
| workflow_execution_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING, | |||
| ) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| draft_var_srv = WorkflowDraftVariableService(db.session()) | |||
| draft_var_srv.prefill_conversation_variable_default_values(workflow) | |||
| var_loader = DraftVarLoader( | |||
| engine=db.engine, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||
| ) | |||
| return self._generate( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| pipeline=pipeline, | |||
| workflow_id=workflow.id, | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| variable_loader=var_loader, | |||
| ) | |||
| def _generate_worker( | |||
| self, | |||
| flask_app: Flask, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| :return: | |||
| """ | |||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||
| try: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow = session.scalar( | |||
| select(Workflow).where( | |||
| Workflow.tenant_id == application_generate_entity.app_config.tenant_id, | |||
| Workflow.app_id == application_generate_entity.app_config.app_id, | |||
| Workflow.id == application_generate_entity.app_config.workflow_id, | |||
| ) | |||
| ) | |||
| if workflow is None: | |||
| raise ValueError("Workflow not found") | |||
| # Determine system_user_id based on invocation source | |||
| is_external_api_call = application_generate_entity.invoke_from in { | |||
| InvokeFrom.WEB_APP, | |||
| InvokeFrom.SERVICE_API, | |||
| } | |||
| if is_external_api_call: | |||
| # For external API calls, use end user's session ID | |||
| end_user = session.scalar( | |||
| select(EndUser).where(EndUser.id == application_generate_entity.user_id) | |||
| ) | |||
| system_user_id = end_user.session_id if end_user else "" | |||
| else: | |||
| # For internal calls, use the original user ID | |||
| system_user_id = application_generate_entity.user_id | |||
| # workflow app | |||
| runner = PipelineRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| variable_loader=variable_loader, | |||
| workflow=workflow, | |||
| system_user_id=system_user_id, | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| finally: | |||
| db.session.close() | |||
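`_generate_worker` runs the pipeline on a background thread and converts every failure into a queue message so the streaming side can report it instead of dying silently. A stripped-down sketch of that thread-plus-queue error funnel using only the standard library:

```python
import queue
import threading

def run_worker(task_queue: queue.Queue, runner) -> None:
    """Run the workload and publish either its completion or the error that stopped it."""
    try:
        runner()
    except Exception as exc:  # the real code distinguishes validation, auth, and unknown errors
        task_queue.put(("error", exc))
    finally:
        task_queue.put(("done", None))

def flaky_runner() -> None:
    raise ValueError("boom")

q: queue.Queue = queue.Queue()
thread = threading.Thread(target=run_worker, args=(q, flaky_runner))
thread.start()
thread.join()

while not q.empty():
    print(q.get())  # ('error', ValueError('boom')), then ('done', None)
```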
| def _handle_response( | |||
| self, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| draft_var_saver_factory: DraftVariableSaverFactory, | |||
| stream: bool = False, | |||
| ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| """ | |||
| Handle response. | |||
| :param application_generate_entity: application generate entity | |||
| :param workflow: workflow | |||
| :param queue_manager: queue manager | |||
| :param user: account or end user | |||
| :param stream: is stream | |||
| :param workflow_node_execution_repository: repository for workflow node execution | |||
| :param draft_var_saver_factory: factory for draft variable savers | |||
| :return: | |||
| """ | |||
| # init generate task pipeline | |||
| generate_task_pipeline = WorkflowAppGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| stream=stream, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| draft_var_saver_factory=draft_var_saver_factory, | |||
| ) | |||
| try: | |||
| return generate_task_pipeline.process() | |||
| except ValueError as e: | |||
| if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise GenerateTaskStoppedError() | |||
| else: | |||
| logger.exception( | |||
| "Fails to process generate task pipeline, task_id: %r", | |||
| application_generate_entity.task_id, | |||
| ) | |||
| raise e | |||
| def _build_document( | |||
| self, | |||
| tenant_id: str, | |||
| dataset_id: str, | |||
| built_in_field_enabled: bool, | |||
| datasource_type: str, | |||
| datasource_info: Mapping[str, Any], | |||
| created_from: str, | |||
| position: int, | |||
| account: Union[Account, EndUser], | |||
| batch: str, | |||
| document_form: str, | |||
| ): | |||
| if datasource_type == "local_file": | |||
| name = datasource_info["name"] | |||
| elif datasource_type == "online_document": | |||
| name = datasource_info["page"]["page_name"] | |||
| elif datasource_type == "website_crawl": | |||
| name = datasource_info["title"] | |||
| elif datasource_type == "online_drive": | |||
| name = datasource_info["key"] | |||
| else: | |||
| raise ValueError(f"Unsupported datasource type: {datasource_type}") | |||
| document = Document( | |||
| tenant_id=tenant_id, | |||
| dataset_id=dataset_id, | |||
| position=position, | |||
| data_source_type=datasource_type, | |||
| data_source_info=json.dumps(datasource_info), | |||
| batch=batch, | |||
| name=name, | |||
| created_from=created_from, | |||
| created_by=account.id, | |||
| doc_form=document_form, | |||
| ) | |||
| doc_metadata = {} | |||
| if built_in_field_enabled: | |||
| doc_metadata = { | |||
| BuiltInField.document_name: name, | |||
| BuiltInField.uploader: account.name, | |||
| BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.source: datasource_type, | |||
| } | |||
| if doc_metadata: | |||
| document.doc_metadata = doc_metadata | |||
| return document | |||
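`_build_document` picks the document name from a different field depending on the datasource type. The same branching can be written as a lookup table of extractors; a sketch using the field names shown above:

```python
from collections.abc import Callable, Mapping
from typing import Any

# One extractor per supported datasource type (field names as in _build_document).
_NAME_EXTRACTORS: Mapping[str, Callable[[Mapping[str, Any]], str]] = {
    "local_file": lambda info: info["name"],
    "online_document": lambda info: info["page"]["page_name"],
    "website_crawl": lambda info: info["title"],
    "online_drive": lambda info: info["key"],
}

def document_name(datasource_type: str, datasource_info: Mapping[str, Any]) -> str:
    extractor = _NAME_EXTRACTORS.get(datasource_type)
    if extractor is None:
        raise ValueError(f"Unsupported datasource type: {datasource_type}")
    return extractor(datasource_info)

print(document_name("website_crawl", {"title": "Example page"}))  # Example page
```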
| def _format_datasource_info_list( | |||
| self, | |||
| datasource_type: str, | |||
| datasource_info_list: list[Mapping[str, Any]], | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| start_node_id: str, | |||
| user: Union[Account, EndUser], | |||
| ) -> list[Mapping[str, Any]]: | |||
| """ | |||
| Format datasource info list. | |||
| """ | |||
| if datasource_type == "online_drive": | |||
| all_files = [] | |||
| datasource_node_data = None | |||
| datasource_nodes = workflow.graph_dict.get("nodes", []) | |||
| for datasource_node in datasource_nodes: | |||
| if datasource_node.get("id") == start_node_id: | |||
| datasource_node_data = datasource_node.get("data", {}) | |||
| break | |||
| if not datasource_node_data: | |||
| raise ValueError("Datasource node data not found") | |||
| from core.datasource.datasource_manager import DatasourceManager | |||
| datasource_runtime = DatasourceManager.get_datasource_runtime( | |||
| provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", | |||
| datasource_name=datasource_node_data.get("datasource_name"), | |||
| tenant_id=pipeline.tenant_id, | |||
| datasource_type=DatasourceProviderType(datasource_type), | |||
| ) | |||
| datasource_provider_service = DatasourceProviderService() | |||
| credentials = datasource_provider_service.get_datasource_credentials( | |||
| tenant_id=pipeline.tenant_id, | |||
| provider=datasource_node_data.get("provider_name"), | |||
| plugin_id=datasource_node_data.get("plugin_id"), | |||
| credential_id=datasource_node_data.get("credential_id"), | |||
| ) | |||
| if credentials: | |||
| datasource_runtime.runtime.credentials = credentials | |||
| datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) | |||
| for datasource_info in datasource_info_list: | |||
| if datasource_info.get("id") and datasource_info.get("type") == "folder": | |||
| # get all files in the folder | |||
| self._get_files_in_folder( | |||
| datasource_runtime, | |||
| datasource_info.get("id", ""), | |||
| datasource_info.get("bucket", None), | |||
| user.id, | |||
| all_files, | |||
| datasource_info, | |||
| None, | |||
| ) | |||
| else: | |||
| all_files.append( | |||
| { | |||
| "id": datasource_info.get("id", ""), | |||
| "bucket": datasource_info.get("bucket", None), | |||
| } | |||
| ) | |||
| return all_files | |||
| else: | |||
| return datasource_info_list | |||
| def _get_files_in_folder( | |||
| self, | |||
| datasource_runtime: OnlineDriveDatasourcePlugin, | |||
| prefix: str, | |||
| bucket: Optional[str], | |||
| user_id: str, | |||
| all_files: list, | |||
| datasource_info: Mapping[str, Any], | |||
| next_page_parameters: Optional[dict] = None, | |||
| ): | |||
| """ | |||
| Get files in a folder. | |||
| """ | |||
| result_generator = datasource_runtime.online_drive_browse_files( | |||
| user_id=user_id, | |||
| request=OnlineDriveBrowseFilesRequest( | |||
| bucket=bucket, | |||
| prefix=prefix, | |||
| max_keys=20, | |||
| next_page_parameters=next_page_parameters, | |||
| ), | |||
| provider_type=datasource_runtime.datasource_provider_type(), | |||
| ) | |||
| is_truncated = False | |||
| last_file_key = None | |||
| for result in result_generator: | |||
| for files in result.result: | |||
| for file in files.files: | |||
| if file.type == "folder": | |||
| self._get_files_in_folder( | |||
| datasource_runtime, | |||
| file.id, | |||
| bucket, | |||
| user_id, | |||
| all_files, | |||
| datasource_info, | |||
| None, | |||
| ) | |||
| else: | |||
| all_files.append( | |||
| { | |||
| "id": file.id, | |||
| "bucket": bucket, | |||
| } | |||
| ) | |||
| is_truncated = files.is_truncated | |||
| next_page_parameters = files.next_page_parameters | |||
| if is_truncated: | |||
| self._get_files_in_folder( | |||
| datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters | |||
| ) | |||
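`_get_files_in_folder` is a depth-first walk over an online drive: folders recurse, files are appended, and a truncated page triggers another call with the provider's `next_page_parameters`. A self-contained sketch of the same walk over an in-memory tree, with a fake paginated `browse` call standing in for the plugin API:

```python
from typing import Optional

# Fake drive: folder id -> list of (child_id, kind)
DRIVE = {
    "root": [("folder-a", "folder"), ("file-1", "file"), ("file-2", "file"), ("file-3", "file")],
    "folder-a": [("file-4", "file")],
}
PAGE_SIZE = 2

def browse(prefix: str, page: int = 0) -> tuple[list[tuple[str, str]], Optional[int]]:
    """Return one page of children plus the next page index, or None when exhausted."""
    children = DRIVE.get(prefix, [])
    start = page * PAGE_SIZE
    chunk = children[start:start + PAGE_SIZE]
    next_page = page + 1 if start + PAGE_SIZE < len(children) else None
    return chunk, next_page

def collect_files(prefix: str, all_files: list[str], page: int = 0) -> None:
    """Depth-first walk: recurse into folders, collect files, follow pagination."""
    chunk, next_page = browse(prefix, page)
    for child_id, kind in chunk:
        if kind == "folder":
            collect_files(child_id, all_files)
        else:
            all_files.append(child_id)
    if next_page is not None:
        collect_files(prefix, all_files, next_page)

files: list[str] = []
collect_files("root", files)
print(files)  # ['file-4', 'file-1', 'file-2', 'file-3']
```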
| @@ -0,0 +1,45 @@ | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.exc import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueErrorEvent, | |||
| QueueMessageEndEvent, | |||
| QueueStopEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| WorkflowQueueMessage, | |||
| ) | |||
| class PipelineQueueManager(AppQueueManager): | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._app_mode = app_mode | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| :param pub_from: | |||
| :return: | |||
| """ | |||
| message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) | |||
| self._q.put(message) | |||
| if isinstance( | |||
| event, | |||
| QueueStopEvent | |||
| | QueueErrorEvent | |||
| | QueueMessageEndEvent | |||
| | QueueWorkflowSucceededEvent | |||
| | QueueWorkflowFailedEvent | |||
| | QueueWorkflowPartialSuccessEvent, | |||
| ): | |||
| self.stop_listen() | |||
| if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): | |||
| raise GenerateTaskStoppedError() | |||
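`_publish` wraps every event in a `WorkflowQueueMessage` and stops listening the moment it sees a terminal event (stop, error, message end, or a workflow conclusion). A minimal sketch of that terminal-event gate with plain classes in place of Dify's queue entities:

```python
import queue

class QueueEvent: ...
class QueueStopEvent(QueueEvent): ...
class QueueErrorEvent(QueueEvent): ...
class QueueWorkflowSucceededEvent(QueueEvent): ...
class QueueTextChunkEvent(QueueEvent): ...

TERMINAL_EVENTS = (QueueStopEvent, QueueErrorEvent, QueueWorkflowSucceededEvent)

class SimpleQueueManager:
    def __init__(self, task_id: str) -> None:
        self._task_id = task_id
        self._q: queue.Queue = queue.Queue()
        self._listening = True

    def publish(self, event: QueueEvent) -> None:
        self._q.put({"task_id": self._task_id, "event": event})
        if isinstance(event, TERMINAL_EVENTS):
            self._listening = False  # the real manager calls stop_listen()

    @property
    def listening(self) -> bool:
        return self._listening

manager = SimpleQueueManager("task-1")
manager.publish(QueueTextChunkEvent())
print(manager.listening)  # True
manager.publish(QueueWorkflowSucceededEvent())
print(manager.listening)  # False
```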
| @@ -0,0 +1,280 @@ | |||
| import logging | |||
| import time | |||
| from typing import Optional, cast | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| RagPipelineGenerateEntity, | |||
| ) | |||
| from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput | |||
| from core.workflow.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document, Pipeline | |||
| from models.enums import UserFrom | |||
| from models.model import EndUser | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| class PipelineRunner(WorkflowBasedAppRunner): | |||
| """ | |||
| Pipeline Application Runner | |||
| """ | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| ) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| self._workflow = workflow | |||
| self._sys_user_id = system_user_id | |||
| def _get_app_id(self) -> str: | |||
| return self.application_generate_entity.app_config.app_id | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| """ | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(PipelineConfig, app_config) | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = self.application_generate_entity.user_id | |||
| pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() | |||
| if not pipeline: | |||
| raise ValueError("Pipeline not found") | |||
| workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| db.session.close() | |||
| # if only single iteration run is requested | |||
| if self.application_generate_entity.single_iteration_run: | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| # if only single loop run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_loop_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| files = self.application_generate_entity.files | |||
| # Create a variable pool. | |||
| system_inputs = SystemVariable( | |||
| files=files, | |||
| user_id=user_id, | |||
| app_id=app_config.app_id, | |||
| workflow_id=app_config.workflow_id, | |||
| workflow_execution_id=self.application_generate_entity.workflow_execution_id, | |||
| document_id=self.application_generate_entity.document_id, | |||
| batch=self.application_generate_entity.batch, | |||
| dataset_id=self.application_generate_entity.dataset_id, | |||
| datasource_type=self.application_generate_entity.datasource_type, | |||
| datasource_info=self.application_generate_entity.datasource_info, | |||
| invoke_from=self.application_generate_entity.invoke_from.value, | |||
| ) | |||
| rag_pipeline_variables = [] | |||
| if workflow.rag_pipeline_variables: | |||
| for v in workflow.rag_pipeline_variables: | |||
| rag_pipeline_variable = RAGPipelineVariable(**v) | |||
| if ( | |||
| rag_pipeline_variable.belong_to_node_id | |||
| in (self.application_generate_entity.start_node_id, "shared") | |||
| ) and rag_pipeline_variable.variable in inputs: | |||
| rag_pipeline_variables.append( | |||
| RAGPipelineVariableInput( | |||
| variable=rag_pipeline_variable, | |||
| value=inputs[rag_pipeline_variable.variable], | |||
| ) | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=[], | |||
| rag_pipeline_variables=rag_pipeline_variables, | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # init graph | |||
| graph = self._init_rag_pipeline_graph( | |||
| graph_runtime_state=graph_runtime_state, | |||
| start_node_id=self.application_generate_entity.start_node_id, | |||
| workflow=workflow, | |||
| ) | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| generator = workflow_entry.run() | |||
| for event in generator: | |||
| self._update_document_status( | |||
| event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id | |||
| ) | |||
| self._handle_event(workflow_entry, event) | |||
| def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| def _init_rag_pipeline_graph( | |||
| self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None | |||
| ) -> Graph: | |||
| """ | |||
| Init pipeline graph | |||
| """ | |||
| graph_config = workflow.graph_dict | |||
| if "nodes" not in graph_config or "edges" not in graph_config: | |||
| raise ValueError("nodes or edges not found in workflow graph") | |||
| if not isinstance(graph_config.get("nodes"), list): | |||
| raise ValueError("nodes in workflow graph must be a list") | |||
| if not isinstance(graph_config.get("edges"), list): | |||
| raise ValueError("edges in workflow graph must be a list") | |||
| # nodes = graph_config.get("nodes", []) | |||
| # edges = graph_config.get("edges", []) | |||
| # real_run_nodes = [] | |||
| # real_edges = [] | |||
| # exclude_node_ids = [] | |||
| # for node in nodes: | |||
| # node_id = node.get("id") | |||
| # node_type = node.get("data", {}).get("type", "") | |||
| # if node_type == "datasource": | |||
| # if start_node_id != node_id: | |||
| # exclude_node_ids.append(node_id) | |||
| # continue | |||
| # real_run_nodes.append(node) | |||
| # for edge in edges: | |||
| # if edge.get("source") in exclude_node_ids: | |||
| # continue | |||
| # real_edges.append(edge) | |||
| # graph_config = dict(graph_config) | |||
| # graph_config["nodes"] = real_run_nodes | |||
| # graph_config["edges"] = real_edges | |||
| # init graph | |||
| # Create required parameters for Graph.init | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=self._app_id, | |||
| workflow_id=workflow.id, | |||
| graph_config=graph_config, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=UserFrom.ACCOUNT.value, | |||
| invoke_from=InvokeFrom.SERVICE_API.value, | |||
| call_depth=0, | |||
| ) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id) | |||
| if not graph: | |||
| raise ValueError("graph not found in workflow") | |||
| return graph | |||
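For reference, the structural checks above only require top-level "nodes" and "edges" lists before Graph.init is called. A minimal shape that would pass them, where the node payloads are illustrative placeholders rather than complete Dify node configurations:

graph_config = {
    "nodes": [
        {"id": "datasource_1", "data": {"type": "datasource"}},
        {"id": "extract_1", "data": {"type": "document-extractor"}},
    ],
    "edges": [{"source": "datasource_1", "target": "extract_1"}],
}

# Mirrors the validation performed in _init_rag_pipeline_graph.
assert "nodes" in graph_config and "edges" in graph_config
assert isinstance(graph_config["nodes"], list)
assert isinstance(graph_config["edges"], list)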
| def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None: | |||
| """ | |||
| Update document status | |||
| """ | |||
| if isinstance(event, GraphRunFailedEvent): | |||
| if document_id and dataset_id: | |||
| document = ( | |||
| db.session.query(Document) | |||
| .filter(Document.id == document_id, Document.dataset_id == dataset_id) | |||
| .first() | |||
| ) | |||
| if document: | |||
| document.indexing_status = "error" | |||
| document.error = event.error or "Unknown error" | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -3,7 +3,7 @@ import logging | |||
| import threading | |||
| import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from typing import Any, Literal, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| @@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| @@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| @@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| @@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: | |||
| files: Sequence[Mapping[str, Any]] = args.get("files") or [] | |||
| @@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_execution_repository=workflow_execution_repository, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| def _generate( | |||
| @@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_execution_repository: WorkflowExecutionRepository, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| @@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| :param workflow_execution_repository: repository for workflow execution | |||
| :param workflow_node_execution_repository: repository for workflow node execution | |||
| :param streaming: is stream | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| # init queue manager | |||
| queue_manager = WorkflowAppQueueManager( | |||
| @@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "context": context, | |||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||
| "variable_loader": variable_loader, | |||
| }, | |||
| ) | |||
| @@ -432,17 +424,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| queue_manager: AppQueueManager, | |||
| context: contextvars.Context, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| :return: | |||
| """ | |||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow = session.scalar( | |||
| @@ -472,7 +454,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| runner = WorkflowAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| variable_loader=variable_loader, | |||
| workflow=workflow, | |||
| system_user_id=system_user_id, | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| from typing import Optional, cast | |||
| import time | |||
| from typing import cast | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfig | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| @@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities import GraphRuntimeState, VariablePool | |||
| from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_redis import redis_client | |||
| from models.enums import UserFrom | |||
| from models.workflow import Workflow, WorkflowType | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| ) -> None: | |||
| @@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| ) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| self._workflow = workflow | |||
| self._sys_user_id = system_user_id | |||
| @@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(WorkflowAppConfig, app_config) | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # if only single iteration run is requested | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| # if only single loop run is requested | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_loop_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| @@ -92,15 +97,26 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| conversation_variables=[], | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=self._workflow.graph_dict) | |||
| graph = self._init_graph( | |||
| graph_config=self._workflow.graph_dict, | |||
| graph_runtime_state=graph_runtime_state, | |||
| workflow_id=self._workflow.id, | |||
| tenant_id=self._workflow.tenant_id, | |||
| ) | |||
| # RUN WORKFLOW | |||
| # Create Redis command channel for this workflow execution | |||
| task_id = self.application_generate_entity.task_id | |||
| channel_key = f"workflow:{task_id}:commands" | |||
| command_channel = RedisChannel(redis_client, channel_key) | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=self._workflow.tenant_id, | |||
| app_id=self._workflow.app_id, | |||
| workflow_id=self._workflow.id, | |||
| workflow_type=WorkflowType.value_of(self._workflow.type), | |||
| graph=graph, | |||
| graph_config=self._workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| @@ -111,11 +127,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| thread_pool_id=self.workflow_thread_pool_id, | |||
| graph_runtime_state=graph_runtime_state, | |||
| command_channel=command_channel, | |||
| ) | |||
| generator = workflow_entry.run(callbacks=workflow_callbacks) | |||
| generator = workflow_entry.run() | |||
| for event in generator: | |||
| self._handle_event(workflow_entry, event) | |||
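Net effect of this hunk: WorkflowEntry no longer receives a bare variable_pool plus a thread_pool_id; it receives a GraphRuntimeState (which carries the pool and the start timestamp) and a Redis-backed command channel keyed by task id. A condensed sketch of the new ingredients, using only constructors and imports that appear in this diff; the task id is a placeholder:

import time

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from extensions.ext_redis import redis_client

task_id = "example-task-id"  # placeholder; the runner uses application_generate_entity.task_id
variable_pool = VariablePool.empty()  # the real run builds the pool from system and user inputs as above
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
command_channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")

Both objects are then handed to WorkflowEntry in place of the removed variable_pool and thread_pool_id arguments.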
| @@ -2,7 +2,7 @@ import logging | |||
| import time | |||
| from collections.abc import Callable, Generator | |||
| from contextlib import contextmanager | |||
| from typing import Any, Optional, Union | |||
| from typing import Optional, Union | |||
| from sqlalchemy.orm import Session | |||
| @@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import ( | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| MessageQueueMessage, | |||
| QueueAgentLogEvent, | |||
| QueueErrorEvent, | |||
| @@ -25,14 +26,9 @@ from core.app.entities.queue_entities import ( | |||
| QueueLoopStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueuePingEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| @@ -57,8 +53,8 @@ from core.app.entities.task_entities import ( | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities import GraphRuntimeState, WorkflowExecution | |||
| from core.workflow.enums import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory | |||
| from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository | |||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| @@ -350,9 +346,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| def _handle_node_failed_events( | |||
| self, | |||
| event: Union[ | |||
| QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent | |||
| ], | |||
| event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent], | |||
| **kwargs, | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle various node failure events.""" | |||
| @@ -371,32 +365,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if node_failed_response: | |||
| yield node_failed_response | |||
| def _handle_parallel_branch_started_event( | |||
| self, event: QueueParallelBranchRunStartedEvent, **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle parallel branch started events.""" | |||
| self._ensure_workflow_initialized() | |||
| parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_start_resp | |||
| def _handle_parallel_branch_finished_events( | |||
| self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| """Handle parallel branch finished events.""" | |||
| self._ensure_workflow_initialized() | |||
| parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_execution_id=self._workflow_run_id, | |||
| event=event, | |||
| ) | |||
| yield parallel_finish_resp | |||
| def _handle_iteration_start_event( | |||
| self, event: QueueIterationStartEvent, **kwargs | |||
| ) -> Generator[StreamResponse, None, None]: | |||
| @@ -618,8 +586,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| QueueNodeRetryEvent: self._handle_node_retry_event, | |||
| QueueNodeStartedEvent: self._handle_node_started_event, | |||
| QueueNodeSucceededEvent: self._handle_node_succeeded_event, | |||
| # Parallel branch events | |||
| QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, | |||
| # Iteration events | |||
| QueueIterationStartEvent: self._handle_iteration_start_event, | |||
| QueueIterationNextEvent: self._handle_iteration_next_event, | |||
| @@ -634,7 +600,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| def _dispatch_event( | |||
| self, | |||
| event: Any, | |||
| event: AppQueueEvent, | |||
| *, | |||
| graph_runtime_state: Optional[GraphRuntimeState] = None, | |||
| tts_publisher: Optional[AppGeneratorTTSPublisher] = None, | |||
| @@ -661,8 +627,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| event, | |||
| ( | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeExceptionEvent, | |||
| ), | |||
| ): | |||
| @@ -675,17 +639,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| ) | |||
| return | |||
| # Handle parallel branch finished events with isinstance check | |||
| if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): | |||
| yield from self._handle_parallel_branch_finished_events( | |||
| event, | |||
| graph_runtime_state=graph_runtime_state, | |||
| tts_publisher=tts_publisher, | |||
| trace_manager=trace_manager, | |||
| queue_message=queue_message, | |||
| ) | |||
| return | |||
| # Handle workflow failed and stop events with isinstance check | |||
| if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)): | |||
| yield from self._handle_workflow_failed_and_stop_events( | |||
| @@ -2,6 +2,7 @@ from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueAgentLogEvent, | |||
| @@ -13,14 +14,9 @@ from core.app.entities.queue_entities import ( | |||
| QueueLoopStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeInLoopFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| QueueParallelBranchRunStartedEvent, | |||
| QueueParallelBranchRunSucceededEvent, | |||
| QueueRetrieverResourcesEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| @@ -28,42 +24,39 @@ from core.app.entities.queue_entities import ( | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| AgentLogEvent, | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import ( | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| IterationRunFailedEvent, | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| LoopRunFailedEvent, | |||
| LoopRunNextEvent, | |||
| LoopRunStartedEvent, | |||
| LoopRunSucceededEvent, | |||
| NodeInIterationFailedEvent, | |||
| NodeInLoopFailedEvent, | |||
| NodeRunAgentLogEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunIterationFailedEvent, | |||
| NodeRunIterationNextEvent, | |||
| NodeRunIterationStartedEvent, | |||
| NodeRunIterationSucceededEvent, | |||
| NodeRunLoopFailedEvent, | |||
| NodeRunLoopNextEvent, | |||
| NodeRunLoopStartedEvent, | |||
| NodeRunLoopSucceededEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunRetryEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ParallelBranchRunFailedEvent, | |||
| ParallelBranchRunStartedEvent, | |||
| ParallelBranchRunSucceededEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_events.graph import GraphRunAbortedEvent | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from models.enums import UserFrom | |||
| from models.workflow import Workflow | |||
| @@ -79,7 +72,14 @@ class WorkflowBasedAppRunner: | |||
| self._variable_loader = variable_loader | |||
| self._app_id = app_id | |||
| def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: | |||
| def _init_graph( | |||
| self, | |||
| graph_config: Mapping[str, Any], | |||
| graph_runtime_state: GraphRuntimeState, | |||
| workflow_id: str = "", | |||
| tenant_id: str = "", | |||
| user_id: str = "", | |||
| ) -> Graph: | |||
| """ | |||
| Init graph | |||
| """ | |||
| @@ -91,8 +91,28 @@ class WorkflowBasedAppRunner: | |||
| if not isinstance(graph_config.get("edges"), list): | |||
| raise ValueError("edges in workflow graph must be a list") | |||
| # Create required parameters for Graph.init | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=tenant_id or "", | |||
| app_id=self._app_id, | |||
| workflow_id=workflow_id, | |||
| graph_config=graph_config, | |||
| user_id=user_id, | |||
| user_from=UserFrom.ACCOUNT.value, | |||
| invoke_from=InvokeFrom.SERVICE_API.value, | |||
| call_depth=0, | |||
| ) | |||
| # Use the provided graph_runtime_state for consistent state management | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init(graph_config=graph_config) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| if not graph: | |||
| raise ValueError("graph not found in workflow") | |||
| @@ -104,6 +124,7 @@ class WorkflowBasedAppRunner: | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_inputs: dict, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| ) -> tuple[Graph, VariablePool]: | |||
| """ | |||
| Get variable pool of single iteration | |||
| @@ -145,8 +166,25 @@ class WorkflowBasedAppRunner: | |||
| graph_config["edges"] = edge_configs | |||
| # Create required parameters for Graph.init | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=self._app_id, | |||
| workflow_id=workflow.id, | |||
| graph_config=graph_config, | |||
| user_id="", | |||
| user_from=UserFrom.ACCOUNT.value, | |||
| invoke_from=InvokeFrom.SERVICE_API.value, | |||
| call_depth=0, | |||
| ) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init(graph_config=graph_config, root_node_id=node_id) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id) | |||
| if not graph: | |||
| raise ValueError("graph not found in workflow") | |||
| @@ -201,6 +239,7 @@ class WorkflowBasedAppRunner: | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user_inputs: dict, | |||
| graph_runtime_state: GraphRuntimeState, | |||
| ) -> tuple[Graph, VariablePool]: | |||
| """ | |||
| Get variable pool of single loop | |||
| @@ -242,8 +281,25 @@ class WorkflowBasedAppRunner: | |||
| graph_config["edges"] = edge_configs | |||
| # Create required parameters for Graph.init | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=self._app_id, | |||
| workflow_id=workflow.id, | |||
| graph_config=graph_config, | |||
| user_id="", | |||
| user_from=UserFrom.ACCOUNT.value, | |||
| invoke_from=InvokeFrom.SERVICE_API.value, | |||
| call_depth=0, | |||
| ) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| # init graph | |||
| graph = Graph.init(graph_config=graph_config, root_node_id=node_id) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id) | |||
| if not graph: | |||
| raise ValueError("graph not found in workflow") | |||
| @@ -310,29 +366,21 @@ class WorkflowBasedAppRunner: | |||
| ) | |||
| elif isinstance(event, GraphRunFailedEvent): | |||
| self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) | |||
| elif isinstance(event, GraphRunAbortedEvent): | |||
| self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0)) | |||
| elif isinstance(event, NodeRunRetryEvent): | |||
| node_run_result = event.route_node_state.node_run_result | |||
| inputs: Mapping[str, Any] | None = {} | |||
| process_data: Mapping[str, Any] | None = {} | |||
| outputs: Mapping[str, Any] | None = {} | |||
| execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {} | |||
| if node_run_result: | |||
| inputs = node_run_result.inputs | |||
| process_data = node_run_result.process_data | |||
| outputs = node_run_result.outputs | |||
| execution_metadata = node_run_result.metadata | |||
| node_run_result = event.node_run_result | |||
| inputs = node_run_result.inputs | |||
| process_data = node_run_result.process_data | |||
| outputs = node_run_result.outputs | |||
| execution_metadata = node_run_result.metadata | |||
| self._publish_event( | |||
| QueueNodeRetryEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_title=event.node_title, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.start_at, | |||
| node_run_index=event.route_node_state.index, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| @@ -343,6 +391,8 @@ class WorkflowBasedAppRunner: | |||
| error=event.error, | |||
| execution_metadata=execution_metadata, | |||
| retry_index=event.retry_index, | |||
| provider_type=event.provider_type, | |||
| provider_id=event.provider_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunStartedEvent): | |||
| @@ -350,44 +400,30 @@ class WorkflowBasedAppRunner: | |||
| QueueNodeStartedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_title=event.node_title, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| node_run_index=event.route_node_state.index, | |||
| start_at=event.start_at, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| agent_strategy=event.agent_strategy, | |||
| provider_type=event.provider_type, | |||
| provider_id=event.provider_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| node_run_result = event.route_node_state.node_run_result | |||
| if node_run_result: | |||
| inputs = node_run_result.inputs | |||
| process_data = node_run_result.process_data | |||
| outputs = node_run_result.outputs | |||
| execution_metadata = node_run_result.metadata | |||
| else: | |||
| inputs = {} | |||
| process_data = {} | |||
| outputs = {} | |||
| execution_metadata = {} | |||
| node_run_result = event.node_run_result | |||
| inputs = node_run_result.inputs | |||
| process_data = node_run_result.process_data | |||
| outputs = node_run_result.outputs | |||
| execution_metadata = node_run_result.metadata | |||
| self._publish_event( | |||
| QueueNodeSucceededEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| start_at=event.start_at, | |||
| inputs=inputs, | |||
| process_data=process_data, | |||
| outputs=outputs, | |||
| @@ -396,34 +432,18 @@ class WorkflowBasedAppRunner: | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeFailedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs or {} | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| error=event.route_node_state.node_run_result.error | |||
| if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error | |||
| else "Unknown error", | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| start_at=event.start_at, | |||
| inputs=event.node_run_result.inputs, | |||
| process_data=event.node_run_result.process_data, | |||
| outputs=event.node_run_result.outputs, | |||
| error=event.node_run_result.error or "Unknown error", | |||
| execution_metadata=event.node_run_result.metadata, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| @@ -434,93 +454,21 @@ class WorkflowBasedAppRunner: | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| error=event.route_node_state.node_run_result.error | |||
| if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error | |||
| else "Unknown error", | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| start_at=event.start_at, | |||
| inputs=event.node_run_result.inputs, | |||
| process_data=event.node_run_result.process_data, | |||
| outputs=event.node_run_result.outputs, | |||
| error=event.node_run_result.error or "Unknown error", | |||
| execution_metadata=event.node_run_result.metadata, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeInIterationFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeInIterationFailedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs or {} | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_iteration_id=event.in_iteration_id, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeInLoopFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeInLoopFailedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs or {} | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_loop_id=event.in_loop_id, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| self._publish_event( | |||
| QueueTextChunkEvent( | |||
| text=event.chunk_content, | |||
| from_variable_selector=event.from_variable_selector, | |||
| text=event.chunk, | |||
| from_variable_selector=list(event.selector), | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| @@ -533,10 +481,10 @@ class WorkflowBasedAppRunner: | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, AgentLogEvent): | |||
| elif isinstance(event, NodeRunAgentLogEvent): | |||
| self._publish_event( | |||
| QueueAgentLogEvent( | |||
| id=event.id, | |||
| id=event.message_id, | |||
| label=event.label, | |||
| node_execution_id=event.node_execution_id, | |||
| parent_id=event.parent_id, | |||
| @@ -547,51 +495,13 @@ class WorkflowBasedAppRunner: | |||
| node_id=event.node_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunStartedEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunStartedEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunSucceededEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunSucceededEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, ParallelBranchRunFailedEvent): | |||
| self._publish_event( | |||
| QueueParallelBranchRunFailedEvent( | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| in_loop_id=event.in_loop_id, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| elif isinstance(event, IterationRunStartedEvent): | |||
| elif isinstance(event, NodeRunIterationStartedEvent): | |||
| self._publish_event( | |||
| QueueIterationStartEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| @@ -599,55 +509,41 @@ class WorkflowBasedAppRunner: | |||
| metadata=event.metadata, | |||
| ) | |||
| ) | |||
| elif isinstance(event, IterationRunNextEvent): | |||
| elif isinstance(event, NodeRunIterationNextEvent): | |||
| self._publish_event( | |||
| QueueIterationNextEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| index=event.index, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| output=event.pre_iteration_output, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| duration=event.duration, | |||
| ) | |||
| ) | |||
| elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): | |||
| elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)): | |||
| self._publish_event( | |||
| QueueIterationCompletedEvent( | |||
| node_execution_id=event.iteration_id, | |||
| node_id=event.iteration_node_id, | |||
| node_type=event.iteration_node_type, | |||
| node_data=event.iteration_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| outputs=event.outputs, | |||
| metadata=event.metadata, | |||
| steps=event.steps, | |||
| error=event.error if isinstance(event, IterationRunFailedEvent) else None, | |||
| error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None, | |||
| ) | |||
| ) | |||
| elif isinstance(event, LoopRunStartedEvent): | |||
| elif isinstance(event, NodeRunLoopStartedEvent): | |||
| self._publish_event( | |||
| QueueLoopStartEvent( | |||
| node_execution_id=event.loop_id, | |||
| node_id=event.loop_node_id, | |||
| node_type=event.loop_node_type, | |||
| node_data=event.loop_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| @@ -655,42 +551,32 @@ class WorkflowBasedAppRunner: | |||
| metadata=event.metadata, | |||
| ) | |||
| ) | |||
| elif isinstance(event, LoopRunNextEvent): | |||
| elif isinstance(event, NodeRunLoopNextEvent): | |||
| self._publish_event( | |||
| QueueLoopNextEvent( | |||
| node_execution_id=event.loop_id, | |||
| node_id=event.loop_node_id, | |||
| node_type=event.loop_node_type, | |||
| node_data=event.loop_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| index=event.index, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| output=event.pre_loop_output, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| duration=event.duration, | |||
| ) | |||
| ) | |||
| elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)): | |||
| elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)): | |||
| self._publish_event( | |||
| QueueLoopCompletedEvent( | |||
| node_execution_id=event.loop_id, | |||
| node_id=event.loop_node_id, | |||
| node_type=event.loop_node_type, | |||
| node_data=event.loop_node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_title=event.node_title, | |||
| start_at=event.start_at, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| inputs=event.inputs, | |||
| outputs=event.outputs, | |||
| metadata=event.metadata, | |||
| steps=event.steps, | |||
| error=event.error if isinstance(event, LoopRunFailedEvent) else None, | |||
| error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None, | |||
| ) | |||
| ) | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from enum import StrEnum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator | |||
| @@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| class InvokeFrom(Enum): | |||
| class InvokeFrom(StrEnum): | |||
| """ | |||
| Invoke From. | |||
| """ | |||
| @@ -35,6 +35,7 @@ class InvokeFrom(Enum): | |||
| # DEBUGGER indicates that this invocation is from | |||
| # the workflow (or chatflow) edit page. | |||
| DEBUGGER = "debugger" | |||
| PUBLISHED = "published" | |||
| @classmethod | |||
| def value_of(cls, value: str): | |||
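Switching InvokeFrom from Enum to StrEnum (Python 3.11+) means members are str subclasses: they compare equal to their raw values and interpolate as plain strings, while .value keeps working for existing call sites such as the invoke_from.value usages elsewhere in this diff. A quick standalone check, reproducing only two of the real members:

from enum import StrEnum

class InvokeFrom(StrEnum):
    DEBUGGER = "debugger"
    PUBLISHED = "published"

assert InvokeFrom.DEBUGGER == "debugger"         # StrEnum members equal their string value
assert f"{InvokeFrom.PUBLISHED}" == "published"  # str()/format() yields the value directly
assert InvokeFrom.DEBUGGER.value == "debugger"   # .value remains available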
| @@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| inputs: dict | |||
| single_loop_run: Optional[SingleLoopRunEntity] = None | |||
| class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): | |||
| """ | |||
| RAG Pipeline Application Generate Entity. | |||
| """ | |||
| # pipeline config | |||
| pipeline_config: WorkflowUIBasedAppConfig | |||
| datasource_type: str | |||
| datasource_info: Mapping[str, Any] | |||
| dataset_id: str | |||
| batch: str | |||
| document_id: Optional[str] = None | |||
| start_node_id: Optional[str] = None | |||
| class SingleIterationRunEntity(BaseModel): | |||
| """ | |||
| Single Iteration Run Entity. | |||
| """ | |||
| node_id: str | |||
| inputs: dict | |||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | |||
| class SingleLoopRunEntity(BaseModel): | |||
| """ | |||
| Single Loop Run Entity. | |||
| """ | |||
| node_id: str | |||
| inputs: dict | |||
| single_loop_run: Optional[SingleLoopRunEntity] = None | |||
| @@ -7,11 +7,9 @@ from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState | |||
| from core.workflow.enums import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base import BaseNodeData | |||
| class QueueEvent(StrEnum): | |||
| @@ -43,9 +41,6 @@ class QueueEvent(StrEnum): | |||
| ANNOTATION_REPLY = "annotation_reply" | |||
| AGENT_THOUGHT = "agent_thought" | |||
| MESSAGE_FILE = "message_file" | |||
| PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" | |||
| PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" | |||
| PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" | |||
| AGENT_LOG = "agent_log" | |||
| ERROR = "error" | |||
| PING = "ping" | |||
| @@ -80,15 +75,7 @@ class QueueIterationStartEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| node_title: str | |||
| start_at: datetime | |||
| node_run_index: int | |||
| @@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteration run in parallel mode run id""" | |||
| node_title: str | |||
| node_run_index: int | |||
| output: Optional[Any] = None # output for the current iteration | |||
| duration: Optional[float] = None | |||
| class QueueIterationCompletedEvent(AppQueueEvent): | |||
| @@ -134,15 +110,7 @@ class QueueIterationCompletedEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| node_title: str | |||
| start_at: datetime | |||
| node_run_index: int | |||
| @@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| node_title: str | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| @@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| node_title: str | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| @@ -204,7 +172,6 @@ class QueueLoopNextEvent(AppQueueEvent): | |||
| """iteration run in parallel mode run id""" | |||
| node_run_index: int | |||
| output: Optional[Any] = None # output for the current loop | |||
| duration: Optional[float] = None | |||
| class QueueLoopCompletedEvent(AppQueueEvent): | |||
| @@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| node_title: str | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| @@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_title: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| node_run_index: int = 1 | |||
| node_run_index: int = 1 # FIXME(-LAN-): may not be used | |||
| predecessor_node_id: Optional[str] = None | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| start_at: datetime | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteration run in parallel mode run id""" | |||
| agent_strategy: Optional[AgentNodeStrategyInit] = None | |||
| # FIXME(-LAN-): only for ToolNode, need to refactor | |||
| provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType | |||
| provider_id: str | |||
| class QueueNodeSucceededEvent(AppQueueEvent): | |||
| """ | |||
| @@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| @@ -417,10 +380,6 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: Optional[str] = None | |||
| """single iteration duration map""" | |||
| iteration_duration_map: Optional[dict[str, float]] = None | |||
| """single loop duration map""" | |||
| loop_duration_map: Optional[dict[str, float]] = None | |||
| class QueueAgentLogEvent(AppQueueEvent): | |||
| @@ -454,72 +413,6 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): | |||
| retry_index: int # retry index | |||
| class QueueNodeInIterationFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeInIterationFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_FAILED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| start_at: datetime | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| class QueueNodeInLoopFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeInLoopFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_FAILED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| start_at: datetime | |||
| inputs: Optional[Mapping[str, Any]] = None | |||
| process_data: Optional[Mapping[str, Any]] = None | |||
| outputs: Optional[Mapping[str, Any]] = None | |||
| execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None | |||
| error: str | |||
| class QueueNodeExceptionEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeExceptionEvent entity | |||
| @@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| @@ -563,15 +455,7 @@ class QueueNodeFailedEvent(AppQueueEvent): | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| @@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage): | |||
| """ | |||
| pass | |||
| class QueueParallelBranchRunStartedEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunStartedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| class QueueParallelBranchRunSucceededEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunSucceededEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| class QueueParallelBranchRunFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueParallelBranchRunFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED | |||
| parallel_id: str | |||
| parallel_start_node_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| """loop id if node is in loop""" | |||
| error: str | |||
| @@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.entities import AgentNodeStrategyInit | |||
| from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| class AnnotationReplyAccount(BaseModel): | |||
| @@ -71,8 +71,6 @@ class StreamEvent(Enum): | |||
| NODE_STARTED = "node_started" | |||
| NODE_FINISHED = "node_finished" | |||
| NODE_RETRY = "node_retry" | |||
| PARALLEL_BRANCH_STARTED = "parallel_branch_started" | |||
| PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" | |||
| ITERATION_STARTED = "iteration_started" | |||
| ITERATION_NEXT = "iteration_next" | |||
| ITERATION_COMPLETED = "iteration_completed" | |||
| @@ -447,54 +445,6 @@ class NodeRetryStreamResponse(StreamResponse): | |||
| } | |||
| class ParallelBranchStartStreamResponse(StreamResponse): | |||
| """ | |||
| ParallelBranchStartStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| parallel_id: str | |||
| parallel_branch_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| loop_id: Optional[str] = None | |||
| created_at: int | |||
| event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED | |||
| workflow_run_id: str | |||
| data: Data | |||
| class ParallelBranchFinishedStreamResponse(StreamResponse): | |||
| """ | |||
| ParallelBranchFinishedStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| parallel_id: str | |||
| parallel_branch_id: str | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| loop_id: Optional[str] = None | |||
| status: str | |||
| error: Optional[str] = None | |||
| created_at: int | |||
| event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED | |||
| workflow_run_id: str | |||
| data: Data | |||
| class IterationNodeStartStreamResponse(StreamResponse): | |||
| """ | |||
| NodeStartStreamResponse entity | |||
| @@ -514,8 +464,6 @@ class IterationNodeStartStreamResponse(StreamResponse): | |||
| metadata: Mapping = {} | |||
| inputs: Mapping = {} | |||
| inputs_truncated: bool = False | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_STARTED | |||
| workflow_run_id: str | |||
| @@ -538,12 +486,7 @@ class IterationNodeNextStreamResponse(StreamResponse): | |||
| title: str | |||
| index: int | |||
| created_at: int | |||
| pre_iteration_output: Optional[Any] = None | |||
| extras: dict = Field(default_factory=dict) | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parallel_mode_run_id: Optional[str] = None | |||
| duration: Optional[float] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_NEXT | |||
| workflow_run_id: str | |||
| @@ -577,8 +520,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse): | |||
| execution_metadata: Optional[Mapping] = None | |||
| finished_at: int | |||
| steps: int | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_COMPLETED | |||
| workflow_run_id: str | |||
| @@ -633,7 +574,6 @@ class LoopNodeNextStreamResponse(StreamResponse): | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parallel_mode_run_id: Optional[str] = None | |||
| duration: Optional[float] = None | |||
| event: StreamEvent = StreamEvent.LOOP_NEXT | |||
| workflow_run_id: str | |||
| @@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel): | |||
| self.current_loop += 1 | |||
| def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None: | |||
| """Run on datasource start.""" | |||
| if dify_config.DEBUG: | |||
| print_text( | |||
| "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n", | |||
| color=self.color, | |||
| ) | |||
| @property | |||
| def ignore_agent(self) -> bool: | |||
| """Whether to ignore agent callbacks.""" | |||
| @@ -0,0 +1,33 @@ | |||
| from abc import ABC, abstractmethod | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| ) | |||
| class DatasourcePlugin(ABC): | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| ) -> None: | |||
| self.entity = entity | |||
| self.runtime = runtime | |||
| @abstractmethod | |||
| def datasource_provider_type(self) -> str: | |||
| """ | |||
| returns the type of the datasource provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return self.__class__( | |||
| entity=self.entity.model_copy(), | |||
| runtime=runtime, | |||
| ) | |||
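For context, a minimal sketch of how this base class is meant to be used (the `EchoDatasourcePlugin` subclass and all values are hypothetical): a plugin is constructed against one runtime and can later be re-bound to a request-scoped runtime via `fork_datasource_runtime` without mutating the original.

```python
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceIdentity,
    DatasourceProviderType,
)
from core.tools.entities.common_entities import I18nObject


class EchoDatasourcePlugin(DatasourcePlugin):
    """Hypothetical concrete plugin, for illustration only."""

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE


entity = DatasourceEntity(
    identity=DatasourceIdentity(
        author="example", name="echo", label=I18nObject(en_US="Echo"), provider="example"
    ),
    description=I18nObject(en_US="Example datasource"),
)
plugin = EchoDatasourcePlugin(entity=entity, runtime=DatasourceRuntime(tenant_id="tenant-1"))

# Re-bind the same datasource definition to a runtime carrying request-scoped
# credentials; the entity is copied and the original plugin is left untouched.
forked = plugin.fork_datasource_runtime(
    DatasourceRuntime(tenant_id="tenant-1", credentials={"token": "example-token"})
)
```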
| @@ -0,0 +1,118 @@ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.entities.provider_entities import ProviderConfig | |||
| from core.plugin.impl.tool import PluginToolManager | |||
| from core.tools.errors import ToolProviderCredentialValidationError | |||
| class DatasourcePluginProviderController(ABC): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| tenant_id: str | |||
| def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: | |||
| self.entity = entity | |||
| self.tenant_id = tenant_id | |||
| @property | |||
| def need_credentials(self) -> bool: | |||
| """ | |||
| returns whether the provider needs credentials | |||
| :return: whether the provider needs credentials | |||
| """ | |||
| return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| """ | |||
| validate the credentials of the provider | |||
| """ | |||
| manager = PluginToolManager() | |||
| if not manager.validate_datasource_credentials( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| provider=self.entity.identity.name, | |||
| credentials=credentials, | |||
| ): | |||
| raise ToolProviderCredentialValidationError("Invalid credentials") | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| @abstractmethod | |||
| def get_datasource(self, datasource_name: str) -> DatasourcePlugin: | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| pass | |||
| def validate_credentials_format(self, credentials: dict[str, Any]) -> None: | |||
| """ | |||
| validate the format of the credentials of the provider and set the default value if needed | |||
| :param credentials: the credentials of the tool | |||
| """ | |||
| credentials_schema = dict[str, ProviderConfig]() | |||
| if self.entity.credentials_schema is None: | |||
| return | |||
| for credential in self.entity.credentials_schema: | |||
| credentials_schema[credential.name] = credential | |||
| credentials_need_to_validate: dict[str, ProviderConfig] = {} | |||
| for credential_name in credentials_schema: | |||
| credentials_need_to_validate[credential_name] = credentials_schema[credential_name] | |||
| for credential_name in credentials: | |||
| if credential_name not in credentials_need_to_validate: | |||
| raise ToolProviderCredentialValidationError( | |||
| f"credential {credential_name} not found in provider {self.entity.identity.name}" | |||
| ) | |||
| # check type | |||
| credential_schema = credentials_need_to_validate[credential_name] | |||
| if not credential_schema.required and credentials[credential_name] is None: | |||
| continue | |||
| if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: | |||
| if not isinstance(credentials[credential_name], str): | |||
| raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") | |||
| elif credential_schema.type == ProviderConfig.Type.SELECT: | |||
| if not isinstance(credentials[credential_name], str): | |||
| raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") | |||
| options = credential_schema.options | |||
| if not isinstance(options, list): | |||
| raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") | |||
| if credentials[credential_name] not in [x.value for x in options]: | |||
| raise ToolProviderCredentialValidationError( | |||
| f"credential {credential_name} should be one of {options}" | |||
| ) | |||
| credentials_need_to_validate.pop(credential_name) | |||
| for credential_name in credentials_need_to_validate: | |||
| credential_schema = credentials_need_to_validate[credential_name] | |||
| if credential_schema.required: | |||
| raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") | |||
| # the credential is not set currently, set the default value if needed | |||
| if credential_schema.default is not None: | |||
| default_value = credential_schema.default | |||
| # parse default value into the correct type | |||
| if credential_schema.type in { | |||
| ProviderConfig.Type.SECRET_INPUT, | |||
| ProviderConfig.Type.TEXT_INPUT, | |||
| ProviderConfig.Type.SELECT, | |||
| }: | |||
| default_value = str(default_value) | |||
| credentials[credential_name] = default_value | |||
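A short usage sketch (the `controller` instance and credential values are hypothetical): the method raises `ToolProviderCredentialValidationError` for unknown or ill-typed fields and fills declared defaults for optional ones in place.

```python
from core.tools.errors import ToolProviderCredentialValidationError

# `controller` is any concrete DatasourcePluginProviderController built elsewhere.
credentials = {"api_key": "sk-example"}  # hypothetical user input
try:
    controller.validate_credentials_format(credentials)
except ToolProviderCredentialValidationError:
    ...  # unknown field, wrong type, or missing required credential
# Afterwards, optional credentials with a declared default are present in `credentials`.
```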
| @@ -0,0 +1,36 @@ | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.datasource.entities.datasource_entities import DatasourceInvokeFrom | |||
| class DatasourceRuntime(BaseModel): | |||
| """ | |||
| Meta data of a datasource call processing | |||
| """ | |||
| tenant_id: str | |||
| datasource_id: Optional[str] = None | |||
| invoke_from: Optional[InvokeFrom] = None | |||
| datasource_invoke_from: Optional[DatasourceInvokeFrom] = None | |||
| credentials: dict[str, Any] = Field(default_factory=dict) | |||
| runtime_parameters: dict[str, Any] = Field(default_factory=dict) | |||
| class FakeDatasourceRuntime(DatasourceRuntime): | |||
| """ | |||
| Fake datasource runtime for testing | |||
| """ | |||
| def __init__(self): | |||
| super().__init__( | |||
| tenant_id="fake_tenant_id", | |||
| datasource_id="fake_datasource_id", | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, | |||
| credentials={}, | |||
| runtime_parameters={}, | |||
| ) | |||
| @@ -0,0 +1,247 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import logging | |||
| import os | |||
| import time | |||
| from datetime import datetime | |||
| from mimetypes import guess_extension, guess_type | |||
| from typing import Optional, Union | |||
| from uuid import uuid4 | |||
| import httpx | |||
| from configs import dify_config | |||
| from core.helper import ssrf_proxy | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.enums import CreatorUserRole | |||
| from models.model import MessageFile, UploadFile | |||
| from models.tools import ToolFile | |||
| logger = logging.getLogger(__name__) | |||
| class DatasourceFileManager: | |||
| @staticmethod | |||
| def sign_file(datasource_file_id: str, extension: str) -> str: | |||
| """ | |||
| sign file to get a temporary url | |||
| """ | |||
| base_url = dify_config.FILES_URL | |||
| file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}" | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" | |||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | |||
| @staticmethod | |||
| def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| """ | |||
| verify signature | |||
| """ | |||
| data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" | |||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||
| # verify signature | |||
| if sign != recalculated_encoded_sign: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| @staticmethod | |||
| def create_file_by_raw( | |||
| *, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| conversation_id: Optional[str], | |||
| file_binary: bytes, | |||
| mimetype: str, | |||
| filename: Optional[str] = None, | |||
| ) -> UploadFile: | |||
| extension = guess_extension(mimetype) or ".bin" | |||
| unique_name = uuid4().hex | |||
| unique_filename = f"{unique_name}{extension}" | |||
| # default just as before | |||
| present_filename = unique_filename | |||
| if filename is not None: | |||
| has_extension = len(filename.split(".")) > 1 | |||
| # Add extension flexibly | |||
| present_filename = filename if has_extension else f"{filename}{extension}" | |||
| filepath = f"datasources/{tenant_id}/{unique_filename}" | |||
| storage.save(filepath, file_binary) | |||
| upload_file = UploadFile( | |||
| tenant_id=tenant_id, | |||
| storage_type=dify_config.STORAGE_TYPE, | |||
| key=filepath, | |||
| name=present_filename, | |||
| size=len(file_binary), | |||
| extension=extension, | |||
| mime_type=mimetype, | |||
| created_by_role=CreatorUserRole.ACCOUNT, | |||
| created_by=user_id, | |||
| used=False, | |||
| hash=hashlib.sha3_256(file_binary).hexdigest(), | |||
| source_url="", | |||
| created_at=datetime.now(), | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| db.session.refresh(upload_file) | |||
| return upload_file | |||
| @staticmethod | |||
| def create_file_by_url( | |||
| user_id: str, | |||
| tenant_id: str, | |||
| file_url: str, | |||
| conversation_id: Optional[str] = None, | |||
| ) -> UploadFile: | |||
| # try to download image | |||
| try: | |||
| response = ssrf_proxy.get(file_url) | |||
| response.raise_for_status() | |||
| blob = response.content | |||
| except httpx.TimeoutException: | |||
| raise ValueError(f"timeout when downloading file from {file_url}") | |||
| mimetype = ( | |||
| guess_type(file_url)[0] | |||
| or response.headers.get("Content-Type", "").split(";")[0].strip() | |||
| or "application/octet-stream" | |||
| ) | |||
| extension = guess_extension(mimetype) or ".bin" | |||
| unique_name = uuid4().hex | |||
| filename = f"{unique_name}{extension}" | |||
| filepath = f"tools/{tenant_id}/{filename}" | |||
| storage.save(filepath, blob) | |||
| upload_file = UploadFile( | |||
| tenant_id=tenant_id, | |||
| storage_type=dify_config.STORAGE_TYPE, | |||
| key=filepath, | |||
| name=filename, | |||
| size=len(blob), | |||
| extension=extension, | |||
| mime_type=mimetype, | |||
| created_by_role=CreatorUserRole.ACCOUNT, | |||
| created_by=user_id, | |||
| used=False, | |||
| hash=hashlib.sha3_256(blob).hexdigest(), | |||
| source_url=file_url, | |||
| created_at=datetime.now(), | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| return upload_file | |||
| @staticmethod | |||
| def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: | |||
| """ | |||
| get file binary | |||
| :param id: the id of the file | |||
| :return: the binary of the file, mime type | |||
| """ | |||
| upload_file: UploadFile | None = ( | |||
| db.session.query(UploadFile) | |||
| .filter( | |||
| UploadFile.id == id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not upload_file: | |||
| return None | |||
| blob = storage.load_once(upload_file.key) | |||
| return blob, upload_file.mime_type | |||
| @staticmethod | |||
| def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: | |||
| """ | |||
| get file binary | |||
| :param id: the id of the file | |||
| :return: the binary of the file, mime type | |||
| """ | |||
| message_file: MessageFile | None = ( | |||
| db.session.query(MessageFile) | |||
| .filter( | |||
| MessageFile.id == id, | |||
| ) | |||
| .first() | |||
| ) | |||
| # Check if message_file is not None | |||
| if message_file is not None: | |||
| # get tool file id | |||
| if message_file.url is not None: | |||
| tool_file_id = message_file.url.split("/")[-1] | |||
| # trim extension | |||
| tool_file_id = tool_file_id.split(".")[0] | |||
| else: | |||
| tool_file_id = None | |||
| else: | |||
| tool_file_id = None | |||
| tool_file: ToolFile | None = ( | |||
| db.session.query(ToolFile) | |||
| .filter( | |||
| ToolFile.id == tool_file_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not tool_file: | |||
| return None | |||
| blob = storage.load_once(tool_file.file_key) | |||
| return blob, tool_file.mimetype | |||
| @staticmethod | |||
| def get_file_generator_by_upload_file_id(upload_file_id: str): | |||
| """ | |||
| get file stream | |||
| :param upload_file_id: the id of the upload file | |||
| :return: the stream of the file, mime type | |||
| """ | |||
| upload_file: UploadFile | None = ( | |||
| db.session.query(UploadFile) | |||
| .filter( | |||
| UploadFile.id == upload_file_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not upload_file: | |||
| return None, None | |||
| stream = storage.load_stream(upload_file.key) | |||
| return stream, upload_file.mime_type | |||
| # init tool_file_parser | |||
| # from core.file.datasource_file_parser import datasource_file_manager | |||
| # | |||
| # datasource_file_manager["manager"] = DatasourceFileManager | |||
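A flow sketch for the manager above (ids are illustrative; it assumes storage, the database session, `SECRET_KEY`, and `FILES_URL` are configured): persist a raw blob, then sign and verify a temporary preview URL.

```python
from urllib.parse import parse_qs, urlparse

# Persist an in-memory blob for a tenant; ".txt" is derived from the mimetype
# because the given filename has no extension.
upload_file = DatasourceFileManager.create_file_by_raw(
    user_id="user-1",
    tenant_id="tenant-1",
    conversation_id=None,
    file_binary=b"hello world",
    mimetype="text/plain",
    filename="notes",
)

# Sign a temporary preview URL, then verify its query parameters on access.
url = DatasourceFileManager.sign_file(datasource_file_id=upload_file.id, extension=".txt")
params = {key: values[0] for key, values in parse_qs(urlparse(url).query).items()}
assert DatasourceFileManager.verify_file(
    datasource_file_id=upload_file.id,
    timestamp=params["timestamp"],
    nonce=params["nonce"],
    sign=params["sign"],
)
```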
| @@ -0,0 +1,108 @@ | |||
| import logging | |||
| from threading import Lock | |||
| from typing import Union | |||
| import contexts | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.entities.common_entities import I18nObject | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderType | |||
| from core.datasource.errors import DatasourceProviderNotFoundError | |||
| from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController | |||
| from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController | |||
| from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController | |||
| from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| logger = logging.getLogger(__name__) | |||
| class DatasourceManager: | |||
| _builtin_provider_lock = Lock() | |||
| _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} | |||
| _builtin_providers_loaded = False | |||
| _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} | |||
| @classmethod | |||
| def get_datasource_plugin_provider( | |||
| cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType | |||
| ) -> DatasourcePluginProviderController: | |||
| """ | |||
| get the datasource plugin provider | |||
| """ | |||
| # check if context is set | |||
| try: | |||
| contexts.datasource_plugin_providers.get() | |||
| except LookupError: | |||
| contexts.datasource_plugin_providers.set({}) | |||
| contexts.datasource_plugin_providers_lock.set(Lock()) | |||
| with contexts.datasource_plugin_providers_lock.get(): | |||
| datasource_plugin_providers = contexts.datasource_plugin_providers.get() | |||
| if provider_id in datasource_plugin_providers: | |||
| return datasource_plugin_providers[provider_id] | |||
| manager = PluginDatasourceManager() | |||
| provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id) | |||
| if not provider_entity: | |||
| raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found") | |||
| match datasource_type: | |||
| case DatasourceProviderType.ONLINE_DOCUMENT: | |||
| controller = OnlineDocumentDatasourcePluginProviderController( | |||
| entity=provider_entity.declaration, | |||
| plugin_id=provider_entity.plugin_id, | |||
| plugin_unique_identifier=provider_entity.plugin_unique_identifier, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| case DatasourceProviderType.ONLINE_DRIVE: | |||
| controller = OnlineDriveDatasourcePluginProviderController( | |||
| entity=provider_entity.declaration, | |||
| plugin_id=provider_entity.plugin_id, | |||
| plugin_unique_identifier=provider_entity.plugin_unique_identifier, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| case DatasourceProviderType.WEBSITE_CRAWL: | |||
| controller = WebsiteCrawlDatasourcePluginProviderController( | |||
| entity=provider_entity.declaration, | |||
| plugin_id=provider_entity.plugin_id, | |||
| plugin_unique_identifier=provider_entity.plugin_unique_identifier, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| case DatasourceProviderType.LOCAL_FILE: | |||
| controller = LocalFileDatasourcePluginProviderController( | |||
| entity=provider_entity.declaration, | |||
| plugin_id=provider_entity.plugin_id, | |||
| plugin_unique_identifier=provider_entity.plugin_unique_identifier, | |||
| tenant_id=tenant_id, | |||
| ) | |||
| case _: | |||
| raise ValueError(f"Unsupported datasource type: {datasource_type}") | |||
| datasource_plugin_providers[provider_id] = controller | |||
| return controller | |||
| @classmethod | |||
| def get_datasource_runtime( | |||
| cls, | |||
| provider_id: str, | |||
| datasource_name: str, | |||
| tenant_id: str, | |||
| datasource_type: DatasourceProviderType, | |||
| ) -> DatasourcePlugin: | |||
| """ | |||
| get the datasource runtime | |||
| :param provider_id: the id of the provider | |||
| :param datasource_name: the name of the datasource | |||
| :param tenant_id: the tenant id | |||
| :param datasource_type: the type of the datasource provider | |||
| :return: the datasource plugin | |||
| """ | |||
| return cls.get_datasource_plugin_provider( | |||
| provider_id, | |||
| tenant_id, | |||
| datasource_type, | |||
| ).get_datasource(datasource_name) | |||
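Resolving a runnable datasource for a tenant might then look like this (the provider id and datasource name are hypothetical):

```python
from core.datasource.entities.datasource_entities import DatasourceProviderType

datasource = DatasourceManager.get_datasource_runtime(
    provider_id="langgenius/notion_datasource/notion",  # hypothetical provider id
    datasource_name="notion_pages",                     # hypothetical datasource name
    tenant_id="tenant-1",
    datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
# `datasource` is an OnlineDocumentDatasourcePlugin bound to a DatasourceRuntime for
# the tenant; provider controllers are cached in the request context under a lock.
```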
| @@ -0,0 +1,71 @@ | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel, Field, field_validator | |||
| from core.datasource.entities.datasource_entities import DatasourceParameter | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.tools.entities.common_entities import I18nObject | |||
| class DatasourceApiEntity(BaseModel): | |||
| author: str | |||
| name: str # identifier | |||
| label: I18nObject # label | |||
| description: I18nObject | |||
| parameters: Optional[list[DatasourceParameter]] = None | |||
| labels: list[str] = Field(default_factory=list) | |||
| output_schema: Optional[dict] = None | |||
| ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] | |||
| class DatasourceProviderApiEntity(BaseModel): | |||
| id: str | |||
| author: str | |||
| name: str # identifier | |||
| description: I18nObject | |||
| icon: str | dict | |||
| label: I18nObject # label | |||
| type: str | |||
| masked_credentials: Optional[dict] = None | |||
| original_credentials: Optional[dict] = None | |||
| is_team_authorization: bool = False | |||
| allow_delete: bool = True | |||
| plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") | |||
| plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") | |||
| datasources: list[DatasourceApiEntity] = Field(default_factory=list) | |||
| labels: list[str] = Field(default_factory=list) | |||
| @field_validator("datasources", mode="before") | |||
| @classmethod | |||
| def convert_none_to_empty_list(cls, v): | |||
| return v if v is not None else [] | |||
| def to_dict(self) -> dict: | |||
| # ------------- | |||
| # overwrite datasource parameter types for temp fix | |||
| datasources = jsonable_encoder(self.datasources) | |||
| for datasource in datasources: | |||
| if datasource.get("parameters"): | |||
| for parameter in datasource.get("parameters"): | |||
| if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: | |||
| parameter["type"] = "files" | |||
| # ------------- | |||
| return { | |||
| "id": self.id, | |||
| "author": self.author, | |||
| "name": self.name, | |||
| "plugin_id": self.plugin_id, | |||
| "plugin_unique_identifier": self.plugin_unique_identifier, | |||
| "description": self.description.to_dict(), | |||
| "icon": self.icon, | |||
| "label": self.label.to_dict(), | |||
| "type": self.type.value, | |||
| "team_credentials": self.masked_credentials, | |||
| "is_team_authorization": self.is_team_authorization, | |||
| "allow_delete": self.allow_delete, | |||
| "datasources": datasources, | |||
| "labels": self.labels, | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| class I18nObject(BaseModel): | |||
| """ | |||
| Model class for i18n object. | |||
| """ | |||
| en_US: str | |||
| zh_Hans: Optional[str] = Field(default=None) | |||
| pt_BR: Optional[str] = Field(default=None) | |||
| ja_JP: Optional[str] = Field(default=None) | |||
| def __init__(self, **data): | |||
| super().__init__(**data) | |||
| self.zh_Hans = self.zh_Hans or self.en_US | |||
| self.pt_BR = self.pt_BR or self.en_US | |||
| self.ja_JP = self.ja_JP or self.en_US | |||
| def to_dict(self) -> dict: | |||
| return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} | |||
| @@ -0,0 +1,363 @@ | |||
| import enum | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field, ValidationInfo, field_validator | |||
| from core.entities.provider_entities import ProviderConfig | |||
| from core.plugin.entities.oauth import OAuthSchema | |||
| from core.plugin.entities.parameters import ( | |||
| PluginParameter, | |||
| PluginParameterOption, | |||
| PluginParameterType, | |||
| as_normal_type, | |||
| cast_parameter_value, | |||
| init_frontend_parameter, | |||
| ) | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum | |||
| class DatasourceProviderType(enum.StrEnum): | |||
| """ | |||
| Enum class for datasource provider | |||
| """ | |||
| ONLINE_DOCUMENT = "online_document" | |||
| LOCAL_FILE = "local_file" | |||
| WEBSITE_CRAWL = "website_crawl" | |||
| ONLINE_DRIVE = "online_drive" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "DatasourceProviderType": | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f"invalid mode value {value}") | |||
| class DatasourceParameter(PluginParameter): | |||
| """ | |||
| Overrides type | |||
| """ | |||
| class DatasourceParameterType(enum.StrEnum): | |||
| """ | |||
| removes TOOLS_SELECTOR from PluginParameterType | |||
| """ | |||
| STRING = PluginParameterType.STRING.value | |||
| NUMBER = PluginParameterType.NUMBER.value | |||
| BOOLEAN = PluginParameterType.BOOLEAN.value | |||
| SELECT = PluginParameterType.SELECT.value | |||
| SECRET_INPUT = PluginParameterType.SECRET_INPUT.value | |||
| FILE = PluginParameterType.FILE.value | |||
| FILES = PluginParameterType.FILES.value | |||
| # deprecated, should not use. | |||
| SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value | |||
| def as_normal_type(self): | |||
| return as_normal_type(self) | |||
| def cast_value(self, value: Any): | |||
| return cast_parameter_value(self, value) | |||
| type: DatasourceParameterType = Field(..., description="The type of the parameter") | |||
| description: I18nObject = Field(..., description="The description of the parameter") | |||
| @classmethod | |||
| def get_simple_instance( | |||
| cls, | |||
| name: str, | |||
| typ: DatasourceParameterType, | |||
| required: bool, | |||
| options: Optional[list[str]] = None, | |||
| ) -> "DatasourceParameter": | |||
| """ | |||
| get a simple datasource parameter | |||
| :param name: the name of the parameter | |||
| :param typ: the type of the parameter | |||
| :param required: if the parameter is required | |||
| :param options: the options of the parameter | |||
| """ | |||
| # convert options to ToolParameterOption | |||
| # FIXME fix the type error | |||
| if options: | |||
| option_objs = [ | |||
| PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) | |||
| for option in options | |||
| ] | |||
| else: | |||
| option_objs = [] | |||
| return cls( | |||
| name=name, | |||
| label=I18nObject(en_US="", zh_Hans=""), | |||
| placeholder=None, | |||
| type=typ, | |||
| required=required, | |||
| options=option_objs, | |||
| description=I18nObject(en_US="", zh_Hans=""), | |||
| ) | |||
| def init_frontend_parameter(self, value: Any): | |||
| return init_frontend_parameter(self, self.type, value) | |||
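For illustration, a simple select parameter can be built with the helper above (the name and options are made up):

```python
param = DatasourceParameter.get_simple_instance(
    name="page_size",
    typ=DatasourceParameter.DatasourceParameterType.SELECT,
    required=False,
    options=["10", "20", "50"],
)
value = param.cast_value("20")  # the value is cast/validated against the declared type
```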
| class DatasourceIdentity(BaseModel): | |||
| author: str = Field(..., description="The author of the datasource") | |||
| name: str = Field(..., description="The name of the datasource") | |||
| label: I18nObject = Field(..., description="The label of the datasource") | |||
| provider: str = Field(..., description="The provider of the datasource") | |||
| icon: Optional[str] = None | |||
| class DatasourceEntity(BaseModel): | |||
| identity: DatasourceIdentity | |||
| parameters: list[DatasourceParameter] = Field(default_factory=list) | |||
| description: I18nObject = Field(..., description="The description of the datasource") | |||
| output_schema: Optional[dict] = None | |||
| @field_validator("parameters", mode="before") | |||
| @classmethod | |||
| def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: | |||
| return v or [] | |||
| class DatasourceProviderIdentity(BaseModel): | |||
| author: str = Field(..., description="The author of the datasource provider") | |||
| name: str = Field(..., description="The name of the datasource provider") | |||
| description: I18nObject = Field(..., description="The description of the datasource provider") | |||
| icon: str = Field(..., description="The icon of the datasource provider") | |||
| label: I18nObject = Field(..., description="The label of the datasource provider") | |||
| tags: Optional[list[ToolLabelEnum]] = Field( | |||
| default=[], | |||
| description="The tags of the datasource provider", | |||
| ) | |||
| class DatasourceProviderEntity(BaseModel): | |||
| """ | |||
| Datasource provider entity | |||
| """ | |||
| identity: DatasourceProviderIdentity | |||
| credentials_schema: list[ProviderConfig] = Field(default_factory=list) | |||
| oauth_schema: Optional[OAuthSchema] = None | |||
| provider_type: DatasourceProviderType | |||
| class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): | |||
| datasources: list[DatasourceEntity] = Field(default_factory=list) | |||
| class DatasourceInvokeMeta(BaseModel): | |||
| """ | |||
| Datasource invoke meta | |||
| """ | |||
| time_cost: float = Field(..., description="The time cost of the datasource invoke") | |||
| error: Optional[str] = None | |||
| tool_config: Optional[dict] = None | |||
| @classmethod | |||
| def empty(cls) -> "DatasourceInvokeMeta": | |||
| """ | |||
| Get an empty instance of DatasourceInvokeMeta | |||
| """ | |||
| return cls(time_cost=0.0, error=None, tool_config={}) | |||
| @classmethod | |||
| def error_instance(cls, error: str) -> "DatasourceInvokeMeta": | |||
| """ | |||
| Get an instance of DatasourceInvokeMeta with error | |||
| """ | |||
| return cls(time_cost=0.0, error=error, tool_config={}) | |||
| def to_dict(self) -> dict: | |||
| return { | |||
| "time_cost": self.time_cost, | |||
| "error": self.error, | |||
| "tool_config": self.tool_config, | |||
| } | |||
| class DatasourceLabel(BaseModel): | |||
| """ | |||
| Datasource label | |||
| """ | |||
| name: str = Field(..., description="The name of the label") | |||
| label: I18nObject = Field(..., description="The i18n text of the label") | |||
| icon: str = Field(..., description="The icon of the label") | |||
| class DatasourceInvokeFrom(Enum): | |||
| """ | |||
| Enum class for datasource invoke | |||
| """ | |||
| RAG_PIPELINE = "rag_pipeline" | |||
| class OnlineDocumentPage(BaseModel): | |||
| """ | |||
| Online document page | |||
| """ | |||
| page_id: str = Field(..., description="The page id") | |||
| page_name: str = Field(..., description="The page title") | |||
| page_icon: Optional[dict] = Field(None, description="The page icon") | |||
| type: str = Field(..., description="The type of the page") | |||
| last_edited_time: str = Field(..., description="The last edited time") | |||
| parent_id: Optional[str] = Field(None, description="The parent page id") | |||
| class OnlineDocumentInfo(BaseModel): | |||
| """ | |||
| Online document info | |||
| """ | |||
| workspace_id: Optional[str] = Field(None, description="The workspace id") | |||
| workspace_name: Optional[str] = Field(None, description="The workspace name") | |||
| workspace_icon: Optional[str] = Field(None, description="The workspace icon") | |||
| total: int = Field(..., description="The total number of documents") | |||
| pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") | |||
| class OnlineDocumentPagesMessage(BaseModel): | |||
| """ | |||
| Get online document pages response | |||
| """ | |||
| result: list[OnlineDocumentInfo] | |||
| class GetOnlineDocumentPageContentRequest(BaseModel): | |||
| """ | |||
| Get online document page content request | |||
| """ | |||
| workspace_id: str = Field(..., description="The workspace id") | |||
| page_id: str = Field(..., description="The page id") | |||
| type: str = Field(..., description="The type of the page") | |||
| class OnlineDocumentPageContent(BaseModel): | |||
| """ | |||
| Online document page content | |||
| """ | |||
| workspace_id: str = Field(..., description="The workspace id") | |||
| page_id: str = Field(..., description="The page id") | |||
| content: str = Field(..., description="The content of the page") | |||
| class GetOnlineDocumentPageContentResponse(BaseModel): | |||
| """ | |||
| Get online document page content response | |||
| """ | |||
| result: OnlineDocumentPageContent | |||
| class GetWebsiteCrawlRequest(BaseModel): | |||
| """ | |||
| Get website crawl request | |||
| """ | |||
| crawl_parameters: dict = Field(..., description="The crawl parameters") | |||
| class WebSiteInfoDetail(BaseModel): | |||
| source_url: str = Field(..., description="The url of the website") | |||
| content: str = Field(..., description="The content of the website") | |||
| title: str = Field(..., description="The title of the website") | |||
| description: str = Field(..., description="The description of the website") | |||
| class WebSiteInfo(BaseModel): | |||
| """ | |||
| Website info | |||
| """ | |||
| status: Optional[str] = Field(..., description="crawl job status") | |||
| web_info_list: Optional[list[WebSiteInfoDetail]] = [] | |||
| total: Optional[int] = Field(default=0, description="The total number of websites") | |||
| completed: Optional[int] = Field(default=0, description="The number of completed websites") | |||
| class WebsiteCrawlMessage(BaseModel): | |||
| """ | |||
| Get website crawl response | |||
| """ | |||
| result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) | |||
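A sketch of a partially completed crawl result using the models above (all values are illustrative):

```python
message = WebsiteCrawlMessage(
    result=WebSiteInfo(
        status="processing",
        web_info_list=[
            WebSiteInfoDetail(
                source_url="https://example.com/docs",
                content="...",
                title="Docs",
                description="Documentation landing page",
            )
        ],
        total=10,      # pages discovered so far
        completed=1,   # pages already crawled
    )
)
```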
| class DatasourceMessage(ToolInvokeMessage): | |||
| pass | |||
| ######################### | |||
| # Online drive file | |||
| ######################### | |||
| class OnlineDriveFile(BaseModel): | |||
| """ | |||
| Online drive file | |||
| """ | |||
| id: str = Field(..., description="The file ID") | |||
| name: str = Field(..., description="The file name") | |||
| size: int = Field(..., description="The file size") | |||
| type: str = Field(..., description="The file type: folder or file") | |||
| class OnlineDriveFileBucket(BaseModel): | |||
| """ | |||
| Online drive file bucket | |||
| """ | |||
| bucket: Optional[str] = Field(None, description="The file bucket") | |||
| files: list[OnlineDriveFile] = Field(..., description="The file list") | |||
| is_truncated: bool = Field(False, description="Whether the result is truncated") | |||
| next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") | |||
| class OnlineDriveBrowseFilesRequest(BaseModel): | |||
| """ | |||
| Get online drive file list request | |||
| """ | |||
| bucket: Optional[str] = Field(None, description="The file bucket") | |||
| prefix: str = Field(..., description="The parent folder ID") | |||
| max_keys: int = Field(20, description="Page size for pagination") | |||
| next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") | |||
| class OnlineDriveBrowseFilesResponse(BaseModel): | |||
| """ | |||
| Get online drive file list response | |||
| """ | |||
| result: list[OnlineDriveFileBucket] = Field(..., description="The list of file buckets") | |||
| class OnlineDriveDownloadFileRequest(BaseModel): | |||
| """ | |||
| Get online drive file | |||
| """ | |||
| id: str = Field(..., description="The id of the file") | |||
| bucket: Optional[str] = Field(None, description="The name of the bucket") | |||
| @@ -0,0 +1,37 @@ | |||
| from core.datasource.entities.datasource_entities import DatasourceInvokeMeta | |||
| class DatasourceProviderNotFoundError(ValueError): | |||
| pass | |||
| class DatasourceNotFoundError(ValueError): | |||
| pass | |||
| class DatasourceParameterValidationError(ValueError): | |||
| pass | |||
| class DatasourceProviderCredentialValidationError(ValueError): | |||
| pass | |||
| class DatasourceNotSupportedError(ValueError): | |||
| pass | |||
| class DatasourceInvokeError(ValueError): | |||
| pass | |||
| class DatasourceApiSchemaError(ValueError): | |||
| pass | |||
| class DatasourceEngineInvokeError(Exception): | |||
| meta: DatasourceInvokeMeta | |||
| def __init__(self, meta: DatasourceInvokeMeta, **kwargs): | |||
| self.meta = meta | |||
| super().__init__(**kwargs) | |||
| @@ -0,0 +1,28 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| ) | |||
| class LocalFileDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def datasource_provider_type(self) -> str: | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| @@ -0,0 +1,56 @@ | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin | |||
| class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity, tenant_id) | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| """ | |||
| validate the credentials of the provider | |||
| """ | |||
| pass | |||
| def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return LocalFileDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,73 @@ | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceMessage, | |||
| DatasourceProviderType, | |||
| GetOnlineDocumentPageContentRequest, | |||
| OnlineDocumentPagesMessage, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| class OnlineDocumentDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def get_online_document_pages( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: Mapping[str, Any], | |||
| provider_type: str, | |||
| ) -> Generator[OnlineDocumentPagesMessage, None, None]: | |||
| manager = PluginDatasourceManager() | |||
| return manager.get_online_document_pages( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def get_online_document_page_content( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: GetOnlineDocumentPageContentRequest, | |||
| provider_type: str, | |||
| ) -> Generator[DatasourceMessage, None, None]: | |||
| manager = PluginDatasourceManager() | |||
| return manager.get_online_document_page_content( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def datasource_provider_type(self) -> str: | |||
| return DatasourceProviderType.ONLINE_DOCUMENT | |||
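A consumption sketch (the `plugin` instance, user id, and parameters are illustrative); both calls return generators that stream messages from the underlying plugin.

```python
provider_type = plugin.datasource_provider_type()  # `plugin` resolved via DatasourceManager

for pages_message in plugin.get_online_document_pages(
    user_id="user-1",
    datasource_parameters={},
    provider_type=provider_type,
):
    for info in pages_message.result:
        for page in info.pages:
            request = GetOnlineDocumentPageContentRequest(
                workspace_id=info.workspace_id or "",
                page_id=page.page_id,
                type=page.type,
            )
            for chunk in plugin.get_online_document_page_content(
                user_id="user-1",
                datasource_parameters=request,
                provider_type=provider_type,
            ):
                ...  # DatasourceMessage chunks carrying the page content
```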
| @@ -0,0 +1,48 @@ | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin | |||
| class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity, tenant_id) | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.ONLINE_DOCUMENT | |||
| def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return OnlineDocumentDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,73 @@ | |||
| from collections.abc import Generator | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceMessage, | |||
| DatasourceProviderType, | |||
| OnlineDriveBrowseFilesRequest, | |||
| OnlineDriveBrowseFilesResponse, | |||
| OnlineDriveDownloadFileRequest, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| class OnlineDriveDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def online_drive_browse_files( | |||
| self, | |||
| user_id: str, | |||
| request: OnlineDriveBrowseFilesRequest, | |||
| provider_type: str, | |||
| ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]: | |||
| manager = PluginDatasourceManager() | |||
| return manager.online_drive_browse_files( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| request=request, | |||
| provider_type=provider_type, | |||
| ) | |||
| def online_drive_download_file( | |||
| self, | |||
| user_id: str, | |||
| request: OnlineDriveDownloadFileRequest, | |||
| provider_type: str, | |||
| ) -> Generator[DatasourceMessage, None, None]: | |||
| manager = PluginDatasourceManager() | |||
| return manager.online_drive_download_file( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| request=request, | |||
| provider_type=provider_type, | |||
| ) | |||
| def datasource_provider_type(self) -> str: | |||
| return DatasourceProviderType.ONLINE_DRIVE | |||
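A browsing sketch with pagination (the `drive_plugin` instance and prefix are illustrative):

```python
request = OnlineDriveBrowseFilesRequest(prefix="", max_keys=20)

for response in drive_plugin.online_drive_browse_files(
    user_id="user-1",
    request=request,
    provider_type=drive_plugin.datasource_provider_type(),
):
    for bucket in response.result:
        for file in bucket.files:
            print(file.id, file.name, file.type)
        if bucket.is_truncated and bucket.next_page_parameters:
            ...  # issue a follow-up request carrying next_page_parameters
```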
| @@ -0,0 +1,48 @@ | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin | |||
| class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity, tenant_id) | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.ONLINE_DRIVE | |||
| def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return OnlineDriveDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,265 @@ | |||
| from copy import deepcopy | |||
| from typing import Any | |||
| from pydantic import BaseModel | |||
| from core.entities.provider_entities import BasicProviderConfig | |||
| from core.helper import encrypter | |||
| from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType | |||
| from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolParameter, | |||
| ToolProviderType, | |||
| ) | |||
| class ProviderConfigEncrypter(BaseModel): | |||
| tenant_id: str | |||
| config: list[BasicProviderConfig] | |||
| provider_type: str | |||
| provider_identity: str | |||
| def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: | |||
| """ | |||
| deep copy data | |||
| """ | |||
| return deepcopy(data) | |||
| def encrypt(self, data: dict[str, str]) -> dict[str, str]: | |||
| """ | |||
| encrypt tool credentials with tenant id | |||
| return a deep copy of credentials with encrypted values | |||
| """ | |||
| data = self._deep_copy(data) | |||
| # get fields need to be decrypted | |||
| fields = dict[str, BasicProviderConfig]() | |||
| for credential in self.config: | |||
| fields[credential.name] = credential | |||
| for field_name, field in fields.items(): | |||
| if field.type == BasicProviderConfig.Type.SECRET_INPUT: | |||
| if field_name in data: | |||
| encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") | |||
| data[field_name] = encrypted | |||
| return data | |||
| def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| mask tool credentials | |||
| return a deep copy of credentials with masked values | |||
| """ | |||
| data = self._deep_copy(data) | |||
| # get fields need to be decrypted | |||
| fields = dict[str, BasicProviderConfig]() | |||
| for credential in self.config: | |||
| fields[credential.name] = credential | |||
| for field_name, field in fields.items(): | |||
| if field.type == BasicProviderConfig.Type.SECRET_INPUT: | |||
| if field_name in data: | |||
| if len(data[field_name]) > 6: | |||
| data[field_name] = ( | |||
| data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] | |||
| ) | |||
| else: | |||
| data[field_name] = "*" * len(data[field_name]) | |||
| return data | |||
| def decrypt(self, data: dict[str, str]) -> dict[str, str]: | |||
| """ | |||
| decrypt tool credentials with tenant id | |||
| return a deep copy of credentials with decrypted values | |||
| """ | |||
| cache = ToolProviderCredentialsCache( | |||
| tenant_id=self.tenant_id, | |||
| identity_id=f"{self.provider_type}.{self.provider_identity}", | |||
| cache_type=ToolProviderCredentialsCacheType.PROVIDER, | |||
| ) | |||
| cached_credentials = cache.get() | |||
| if cached_credentials: | |||
| return cached_credentials | |||
| data = self._deep_copy(data) | |||
| # get fields that need to be decrypted | |||
| fields = dict[str, BasicProviderConfig]() | |||
| for credential in self.config: | |||
| fields[credential.name] = credential | |||
| for field_name, field in fields.items(): | |||
| if field.type == BasicProviderConfig.Type.SECRET_INPUT: | |||
| if field_name in data: | |||
| try: | |||
| # if the value is None or empty string, skip decrypt | |||
| if not data[field_name]: | |||
| continue | |||
| data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) | |||
| except Exception: | |||
| pass | |||
| cache.set(data) | |||
| return data | |||
| def delete_tool_credentials_cache(self): | |||
| cache = ToolProviderCredentialsCache( | |||
| tenant_id=self.tenant_id, | |||
| identity_id=f"{self.provider_type}.{self.provider_identity}", | |||
| cache_type=ToolProviderCredentialsCacheType.PROVIDER, | |||
| ) | |||
| cache.delete() | |||
| class ToolParameterConfigurationManager: | |||
| """ | |||
| Tool parameter configuration manager | |||
| """ | |||
| tenant_id: str | |||
| tool_runtime: Tool | |||
| provider_name: str | |||
| provider_type: ToolProviderType | |||
| identity_id: str | |||
| def __init__( | |||
| self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str | |||
| ) -> None: | |||
| self.tenant_id = tenant_id | |||
| self.tool_runtime = tool_runtime | |||
| self.provider_name = provider_name | |||
| self.provider_type = provider_type | |||
| self.identity_id = identity_id | |||
| def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| deep copy parameters | |||
| """ | |||
| return deepcopy(parameters) | |||
| def _merge_parameters(self) -> list[ToolParameter]: | |||
| """ | |||
| merge parameters | |||
| """ | |||
| # get tool parameters | |||
| tool_parameters = self.tool_runtime.entity.parameters or [] | |||
| # get tool runtime parameters | |||
| runtime_parameters = self.tool_runtime.get_runtime_parameters() | |||
| # override parameters | |||
| current_parameters = tool_parameters.copy() | |||
| for runtime_parameter in runtime_parameters: | |||
| found = False | |||
| for index, parameter in enumerate(current_parameters): | |||
| if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: | |||
| current_parameters[index] = runtime_parameter | |||
| found = True | |||
| break | |||
| if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: | |||
| current_parameters.append(runtime_parameter) | |||
| return current_parameters | |||
| def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| mask tool parameters | |||
| return a deep copy of parameters with masked values | |||
| """ | |||
| parameters = self._deep_copy(parameters) | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| for parameter in current_parameters: | |||
| if ( | |||
| parameter.form == ToolParameter.ToolParameterForm.FORM | |||
| and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT | |||
| ): | |||
| if parameter.name in parameters: | |||
| if len(parameters[parameter.name]) > 6: | |||
| parameters[parameter.name] = ( | |||
| parameters[parameter.name][:2] | |||
| + "*" * (len(parameters[parameter.name]) - 4) | |||
| + parameters[parameter.name][-2:] | |||
| ) | |||
| else: | |||
| parameters[parameter.name] = "*" * len(parameters[parameter.name]) | |||
| return parameters | |||
| def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| encrypt tool parameters with tenant id | |||
| return a deep copy of parameters with encrypted values | |||
| """ | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| parameters = self._deep_copy(parameters) | |||
| for parameter in current_parameters: | |||
| if ( | |||
| parameter.form == ToolParameter.ToolParameterForm.FORM | |||
| and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT | |||
| ): | |||
| if parameter.name in parameters: | |||
| encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| parameters[parameter.name] = encrypted | |||
| return parameters | |||
| def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| decrypt tool parameters with tenant id | |||
| return a deep copy of parameters with decrypted values | |||
| """ | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| provider=f"{self.provider_type.value}.{self.provider_name}", | |||
| tool_name=self.tool_runtime.entity.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER, | |||
| identity_id=self.identity_id, | |||
| ) | |||
| cached_parameters = cache.get() | |||
| if cached_parameters: | |||
| return cached_parameters | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| has_secret_input = False | |||
| for parameter in current_parameters: | |||
| if ( | |||
| parameter.form == ToolParameter.ToolParameterForm.FORM | |||
| and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT | |||
| ): | |||
| if parameter.name in parameters: | |||
| try: | |||
| has_secret_input = True | |||
| parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| except Exception: | |||
| pass | |||
| if has_secret_input: | |||
| cache.set(parameters) | |||
| return parameters | |||
| def delete_tool_parameters_cache(self): | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| provider=f"{self.provider_type.value}.{self.provider_name}", | |||
| tool_name=self.tool_runtime.entity.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER, | |||
| identity_id=self.identity_id, | |||
| ) | |||
| cache.delete() | |||
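| A minimal usage sketch for the two helpers above, assuming a live application/database context (encrypt_token looks up the tenant, and the credential/parameter caches must be reachable), that BasicProviderConfig accepts name/type keyword arguments as implied by its usage above, and that `tool` is an instantiated Tool runtime; the provider names and ToolProviderType.API are illustrative. | |||
| from core.entities.provider_entities import BasicProviderConfig | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| credential_encrypter = ProviderConfigEncrypter( | |||
|     tenant_id="tenant-id", | |||
|     config=[BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key")], | |||
|     provider_type=ToolProviderType.API.value, | |||
|     provider_identity="example_provider", | |||
| ) | |||
| stored = credential_encrypter.encrypt({"api_key": "sk-raw-secret"})  # secret fields encrypted per tenant | |||
| shown = credential_encrypter.mask_tool_credentials({"api_key": "sk-raw-secret"})  # first/last two chars kept | |||
| plain = credential_encrypter.decrypt(stored)  # result is cached after the first call | |||
| parameter_manager = ToolParameterConfigurationManager( | |||
|     tenant_id="tenant-id", | |||
|     tool_runtime=tool, | |||
|     provider_name="example_provider", | |||
|     provider_type=ToolProviderType.API, | |||
|     identity_id="provider-or-app-id", | |||
| ) | |||
| raw_parameters = {"api_key": "secret-value", "query": "hello"} | |||
| masked = parameter_manager.mask_tool_parameters(raw_parameters)  # safe to display | |||
| stored_parameters = parameter_manager.encrypt_tool_parameters(raw_parameters)  # only SECRET_INPUT form fields change | |||
| restored = parameter_manager.decrypt_tool_parameters(stored_parameters) | |||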
| @@ -0,0 +1,124 @@ | |||
| import logging | |||
| from collections.abc import Generator | |||
| from mimetypes import guess_extension, guess_type | |||
| from typing import Optional | |||
| from core.datasource.datasource_file_manager import DatasourceFileManager | |||
| from core.datasource.entities.datasource_entities import DatasourceMessage | |||
| from core.file import File, FileTransferMethod, FileType | |||
| logger = logging.getLogger(__name__) | |||
| class DatasourceFileMessageTransformer: | |||
| @classmethod | |||
| def transform_datasource_invoke_messages( | |||
| cls, | |||
| messages: Generator[DatasourceMessage, None, None], | |||
| user_id: str, | |||
| tenant_id: str, | |||
| conversation_id: Optional[str] = None, | |||
| ) -> Generator[DatasourceMessage, None, None]: | |||
| """ | |||
| Transform datasource message and handle file download | |||
| """ | |||
| for message in messages: | |||
| if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}: | |||
| yield message | |||
| elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance( | |||
| message.message, DatasourceMessage.TextMessage | |||
| ): | |||
| # try to download image | |||
| try: | |||
| assert isinstance(message.message, DatasourceMessage.TextMessage) | |||
| file = DatasourceFileManager.create_file_by_url( | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| file_url=message.message.text, | |||
| conversation_id=conversation_id, | |||
| ) | |||
| url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.IMAGE_LINK, | |||
| message=DatasourceMessage.TextMessage(text=url), | |||
| meta=message.meta.copy() if message.meta is not None else {}, | |||
| ) | |||
| except Exception as e: | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.TEXT, | |||
| message=DatasourceMessage.TextMessage( | |||
| text=f"Failed to download image: {message.message.text}: {e}" | |||
| ), | |||
| meta=message.meta.copy() if message.meta is not None else {}, | |||
| ) | |||
| elif message.type == DatasourceMessage.MessageType.BLOB: | |||
| # get mime type and save blob to storage | |||
| meta = message.meta or {} | |||
| # get filename from meta | |||
| filename = meta.get("file_name", None) | |||
| mimetype = meta.get("mime_type") | |||
| if not mimetype: | |||
| mimetype = (guess_type(filename)[0] if filename else None) or "application/octet-stream" | |||
| # only blob messages are expected for the BLOB type | |||
| if not isinstance(message.message, DatasourceMessage.BlobMessage): | |||
| raise ValueError("unexpected message type") | |||
| # FIXME: should do a type check here. | |||
| assert isinstance(message.message.blob, bytes) | |||
| file = DatasourceFileManager.create_file_by_raw( | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| conversation_id=conversation_id, | |||
| file_binary=message.message.blob, | |||
| mimetype=mimetype, | |||
| filename=filename, | |||
| ) | |||
| url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type)) | |||
| # check if file is image | |||
| if "image" in mimetype: | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.IMAGE_LINK, | |||
| message=DatasourceMessage.TextMessage(text=url), | |||
| meta=meta.copy() if meta is not None else {}, | |||
| ) | |||
| else: | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.BINARY_LINK, | |||
| message=DatasourceMessage.TextMessage(text=url), | |||
| meta=meta.copy() if meta is not None else {}, | |||
| ) | |||
| elif message.type == DatasourceMessage.MessageType.FILE: | |||
| meta = message.meta or {} | |||
| file = meta.get("file", None) | |||
| if isinstance(file, File): | |||
| if file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| assert file.related_id is not None | |||
| url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) | |||
| if file.type == FileType.IMAGE: | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.IMAGE_LINK, | |||
| message=DatasourceMessage.TextMessage(text=url), | |||
| meta=meta.copy() if meta is not None else {}, | |||
| ) | |||
| else: | |||
| yield DatasourceMessage( | |||
| type=DatasourceMessage.MessageType.LINK, | |||
| message=DatasourceMessage.TextMessage(text=url), | |||
| meta=meta.copy() if meta is not None else {}, | |||
| ) | |||
| else: | |||
| yield message | |||
| else: | |||
| yield message | |||
| @classmethod | |||
| def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: | |||
| return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" | |||
| @@ -0,0 +1,389 @@ | |||
| import re | |||
| import uuid | |||
| from json import dumps as json_dumps | |||
| from json import loads as json_loads | |||
| from json.decoder import JSONDecodeError | |||
| from typing import Optional | |||
| from flask import request | |||
| from requests import get | |||
| from yaml import YAMLError, safe_load # type: ignore | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_bundle import ApiToolBundle | |||
| from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter | |||
| from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError | |||
| class ApiBasedToolSchemaParser: | |||
| @staticmethod | |||
| def parse_openapi_to_tool_bundle( | |||
| openapi: dict, extra_info: dict | None = None, warning: dict | None = None | |||
| ) -> list[ApiToolBundle]: | |||
| warning = warning if warning is not None else {} | |||
| extra_info = extra_info if extra_info is not None else {} | |||
| # set description to extra_info | |||
| extra_info["description"] = openapi["info"].get("description", "") | |||
| if len(openapi["servers"]) == 0: | |||
| raise ToolProviderNotFoundError("No server found in the openapi yaml.") | |||
| server_url = openapi["servers"][0]["url"] | |||
| request_env = request.headers.get("X-Request-Env") | |||
| if request_env: | |||
| matched_servers = [server["url"] for server in openapi["servers"] if server.get("env") == request_env] | |||
| server_url = matched_servers[0] if matched_servers else server_url | |||
| # list all interfaces | |||
| interfaces = [] | |||
| for path, path_item in openapi["paths"].items(): | |||
| methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] | |||
| for method in methods: | |||
| if method in path_item: | |||
| interfaces.append( | |||
| { | |||
| "path": path, | |||
| "method": method, | |||
| "operation": path_item[method], | |||
| } | |||
| ) | |||
| # get all parameters | |||
| bundles = [] | |||
| for interface in interfaces: | |||
| # convert parameters | |||
| parameters = [] | |||
| if "parameters" in interface["operation"]: | |||
| for parameter in interface["operation"]["parameters"]: | |||
| tool_parameter = ToolParameter( | |||
| name=parameter["name"], | |||
| label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), | |||
| human_description=I18nObject( | |||
| en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") | |||
| ), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| required=parameter.get("required", False), | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description=parameter.get("description"), | |||
| default=parameter["schema"]["default"] | |||
| if "schema" in parameter and "default" in parameter["schema"] | |||
| else None, | |||
| placeholder=I18nObject( | |||
| en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") | |||
| ), | |||
| ) | |||
| # check if there is a type | |||
| typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) | |||
| if typ: | |||
| tool_parameter.type = typ | |||
| parameters.append(tool_parameter) | |||
| # create tool bundle | |||
| # check if there is a request body | |||
| if "requestBody" in interface["operation"]: | |||
| request_body = interface["operation"]["requestBody"] | |||
| if "content" in request_body: | |||
| for content_type, content in request_body["content"].items(): | |||
| # if there is a reference, get the reference and overwrite the content | |||
| if "schema" not in content: | |||
| continue | |||
| if "$ref" in content["schema"]: | |||
| # get the reference | |||
| root = openapi | |||
| reference = content["schema"]["$ref"].split("/")[1:] | |||
| for ref in reference: | |||
| root = root[ref] | |||
| # overwrite the content | |||
| interface["operation"]["requestBody"]["content"][content_type]["schema"] = root | |||
| # parse body parameters | |||
| if "schema" in interface["operation"]["requestBody"]["content"][content_type]: | |||
| body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] | |||
| required = body_schema.get("required", []) | |||
| properties = body_schema.get("properties", {}) | |||
| for name, property in properties.items(): | |||
| tool = ToolParameter( | |||
| name=name, | |||
| label=I18nObject(en_US=name, zh_Hans=name), | |||
| human_description=I18nObject( | |||
| en_US=property.get("description", ""), zh_Hans=property.get("description", "") | |||
| ), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| required=name in required, | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description=property.get("description", ""), | |||
| default=property.get("default", None), | |||
| placeholder=I18nObject( | |||
| en_US=property.get("description", ""), zh_Hans=property.get("description", "") | |||
| ), | |||
| ) | |||
| # check if there is a type | |||
| typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) | |||
| if typ: | |||
| tool.type = typ | |||
| parameters.append(tool) | |||
| # check if any parameter is duplicated | |||
| parameters_count = {} | |||
| for parameter in parameters: | |||
| if parameter.name not in parameters_count: | |||
| parameters_count[parameter.name] = 0 | |||
| parameters_count[parameter.name] += 1 | |||
| for name, count in parameters_count.items(): | |||
| if count > 1: | |||
| warning["duplicated_parameter"] = f"Parameter {name} is duplicated." | |||
| # check if there is an operation id, use $path_$method as operation id if not | |||
| if "operationId" not in interface["operation"]: | |||
| # strip the leading slash before sanitizing the path | |||
| path = interface["path"] | |||
| if interface["path"].startswith("/"): | |||
| path = interface["path"][1:] | |||
| # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ | |||
| path = re.sub(r"[^a-zA-Z0-9_-]", "", path) | |||
| if not path: | |||
| path = str(uuid.uuid4()) | |||
| interface["operation"]["operationId"] = f"{path}_{interface['method']}" | |||
| bundles.append( | |||
| ApiToolBundle( | |||
| server_url=server_url + interface["path"], | |||
| method=interface["method"], | |||
| summary=interface["operation"]["description"] | |||
| if "description" in interface["operation"] | |||
| else interface["operation"].get("summary", None), | |||
| operation_id=interface["operation"]["operationId"], | |||
| parameters=parameters, | |||
| author="", | |||
| icon=None, | |||
| openapi=interface["operation"], | |||
| ) | |||
| ) | |||
| return bundles | |||
| @staticmethod | |||
| def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: | |||
| parameter = parameter or {} | |||
| typ: Optional[str] = None | |||
| if parameter.get("format") == "binary": | |||
| return ToolParameter.ToolParameterType.FILE | |||
| if "type" in parameter: | |||
| typ = parameter["type"] | |||
| elif "schema" in parameter and "type" in parameter["schema"]: | |||
| typ = parameter["schema"]["type"] | |||
| if typ in {"integer", "number"}: | |||
| return ToolParameter.ToolParameterType.NUMBER | |||
| elif typ == "boolean": | |||
| return ToolParameter.ToolParameterType.BOOLEAN | |||
| elif typ == "string": | |||
| return ToolParameter.ToolParameterType.STRING | |||
| elif typ == "array": | |||
| items = parameter.get("items") or parameter.get("schema", {}).get("items") | |||
| return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None | |||
| else: | |||
| return None | |||
| @staticmethod | |||
| def parse_openapi_yaml_to_tool_bundle( | |||
| yaml: str, extra_info: dict | None = None, warning: dict | None = None | |||
| ) -> list[ApiToolBundle]: | |||
| """ | |||
| parse openapi yaml to tool bundle | |||
| :param yaml: the yaml string | |||
| :param extra_info: the extra info | |||
| :param warning: the warning message | |||
| :return: the tool bundle | |||
| """ | |||
| warning = warning if warning is not None else {} | |||
| extra_info = extra_info if extra_info is not None else {} | |||
| openapi: dict = safe_load(yaml) | |||
| if openapi is None: | |||
| raise ToolApiSchemaError("Invalid openapi yaml.") | |||
| return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) | |||
| @staticmethod | |||
| def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: | |||
| """ | |||
| parse swagger to openapi | |||
| :param swagger: the swagger dict | |||
| :param extra_info: the extra info | |||
| :param warning: the warning message | |||
| :return: the openapi dict | |||
| """ | |||
| warning = warning or {} | |||
| # convert swagger to openapi | |||
| info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) | |||
| servers = swagger.get("servers", []) | |||
| if len(servers) == 0: | |||
| raise ToolApiSchemaError("No server found in the swagger yaml.") | |||
| openapi = { | |||
| "openapi": "3.0.0", | |||
| "info": { | |||
| "title": info.get("title", "Swagger"), | |||
| "description": info.get("description", "Swagger"), | |||
| "version": info.get("version", "1.0.0"), | |||
| }, | |||
| "servers": swagger["servers"], | |||
| "paths": {}, | |||
| "components": {"schemas": {}}, | |||
| } | |||
| # check paths | |||
| if "paths" not in swagger or len(swagger["paths"]) == 0: | |||
| raise ToolApiSchemaError("No paths found in the swagger yaml.") | |||
| # convert paths | |||
| for path, path_item in swagger["paths"].items(): | |||
| openapi["paths"][path] = {} | |||
| for method, operation in path_item.items(): | |||
| if "operationId" not in operation: | |||
| raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") | |||
| if ("summary" not in operation or len(operation["summary"]) == 0) and ( | |||
| "description" not in operation or len(operation["description"]) == 0 | |||
| ): | |||
| if warning is not None: | |||
| warning["missing_summary"] = f"No summary or description found in operation {method} {path}." | |||
| openapi["paths"][path][method] = { | |||
| "operationId": operation["operationId"], | |||
| "summary": operation.get("summary", ""), | |||
| "description": operation.get("description", ""), | |||
| "parameters": operation.get("parameters", []), | |||
| "responses": operation.get("responses", {}), | |||
| } | |||
| if "requestBody" in operation: | |||
| openapi["paths"][path][method]["requestBody"] = operation["requestBody"] | |||
| # convert definitions | |||
| for name, definition in swagger.get("definitions", {}).items(): | |||
| openapi["components"]["schemas"][name] = definition | |||
| return openapi | |||
| @staticmethod | |||
| def parse_openai_plugin_json_to_tool_bundle( | |||
| json: str, extra_info: dict | None = None, warning: dict | None = None | |||
| ) -> list[ApiToolBundle]: | |||
| """ | |||
| parse openai plugin json to tool bundle | |||
| :param json: the json string | |||
| :param extra_info: the extra info | |||
| :param warning: the warning message | |||
| :return: the tool bundle | |||
| """ | |||
| warning = warning if warning is not None else {} | |||
| extra_info = extra_info if extra_info is not None else {} | |||
| try: | |||
| openai_plugin = json_loads(json) | |||
| api = openai_plugin["api"] | |||
| api_url = api["url"] | |||
| api_type = api["type"] | |||
| except JSONDecodeError: | |||
| raise ToolProviderNotFoundError("Invalid openai plugin json.") | |||
| if api_type != "openapi": | |||
| raise ToolNotSupportedError("Only openapi is supported now.") | |||
| # get openapi yaml | |||
| response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) | |||
| if response.status_code != 200: | |||
| raise ToolProviderNotFoundError("cannot get openapi yaml from url.") | |||
| return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( | |||
| response.text, extra_info=extra_info, warning=warning | |||
| ) | |||
| @staticmethod | |||
| def auto_parse_to_tool_bundle( | |||
| content: str, extra_info: dict | None = None, warning: dict | None = None | |||
| ) -> tuple[list[ApiToolBundle], str]: | |||
| """ | |||
| auto parse to tool bundle | |||
| :param content: the content | |||
| :param extra_info: the extra info | |||
| :param warning: the warning message | |||
| :return: tools bundle, schema_type | |||
| """ | |||
| warning = warning if warning is not None else {} | |||
| extra_info = extra_info if extra_info is not None else {} | |||
| content = content.strip() | |||
| loaded_content = None | |||
| json_error = None | |||
| yaml_error = None | |||
| try: | |||
| loaded_content = json_loads(content) | |||
| except JSONDecodeError as e: | |||
| json_error = e | |||
| if loaded_content is None: | |||
| try: | |||
| loaded_content = safe_load(content) | |||
| except YAMLError as e: | |||
| yaml_error = e | |||
| if loaded_content is None: | |||
| raise ToolApiSchemaError( | |||
| f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}," | |||
| f" yaml error: {str(yaml_error)}" | |||
| ) | |||
| swagger_error = None | |||
| openapi_error = None | |||
| openapi_plugin_error = None | |||
| schema_type = None | |||
| try: | |||
| openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( | |||
| loaded_content, extra_info=extra_info, warning=warning | |||
| ) | |||
| schema_type = ApiProviderSchemaType.OPENAPI.value | |||
| return openapi, schema_type | |||
| except ToolApiSchemaError as e: | |||
| openapi_error = e | |||
| # openapi parse error, fall back to swagger | |||
| try: | |||
| converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( | |||
| loaded_content, extra_info=extra_info, warning=warning | |||
| ) | |||
| schema_type = ApiProviderSchemaType.SWAGGER.value | |||
| return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( | |||
| converted_swagger, extra_info=extra_info, warning=warning | |||
| ), schema_type | |||
| except ToolApiSchemaError as e: | |||
| swagger_error = e | |||
| # swagger parse error, fall back to openai plugin | |||
| try: | |||
| openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( | |||
| json_dumps(loaded_content), extra_info=extra_info, warning=warning | |||
| ) | |||
| return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value | |||
| except ToolNotSupportedError as e: | |||
| # maybe it's not a plugin at all | |||
| openapi_plugin_error = e | |||
| raise ToolApiSchemaError( | |||
| f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}," | |||
| f" openapi plugin error: {str(openapi_plugin_error)}" | |||
| ) | |||
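| A sketch of the parser on a minimal inline OpenAPI document (json_dumps is the alias imported at the top of this module). Note that parse_openapi_to_tool_bundle reads the X-Request-Env header via flask.request, so this has to run inside a Flask (test) request context; the schema below is illustrative. | |||
| minimal_openapi = { | |||
|     "openapi": "3.0.0", | |||
|     "info": {"title": "Echo", "description": "Echo API", "version": "1.0.0"}, | |||
|     "servers": [{"url": "https://api.example.com"}], | |||
|     "paths": { | |||
|         "/echo": { | |||
|             "get": { | |||
|                 "operationId": "echo", | |||
|                 "summary": "Echo a message", | |||
|                 "parameters": [ | |||
|                     {"name": "message", "in": "query", "required": True, "schema": {"type": "string"}} | |||
|                 ], | |||
|                 "responses": {"200": {"description": "ok"}}, | |||
|             } | |||
|         } | |||
|     }, | |||
| } | |||
| bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(json_dumps(minimal_openapi)) | |||
| # schema_type == ApiProviderSchemaType.OPENAPI.value and bundles[0].operation_id == "echo" | |||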
| @@ -0,0 +1,17 @@ | |||
| import re | |||
| def remove_leading_symbols(text: str) -> str: | |||
| """ | |||
| Remove leading punctuation or symbols from the given text. | |||
| Args: | |||
| text (str): The input text to process. | |||
| Returns: | |||
| str: The text with leading punctuation or symbols removed. | |||
| """ | |||
| # Match Unicode ranges for punctuation and symbols | |||
| # FIXME: this pattern is a confusing quick fix for #11868; maybe refactor it later | |||
| pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" | |||
| return re.sub(pattern, "", text) | |||
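| Expected behaviour of the helper above, for reference: | |||
| assert remove_leading_symbols("...hello, world") == "hello, world" | |||
| assert remove_leading_symbols("hello...") == "hello..."  # only leading symbols are stripped | |||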
| @@ -0,0 +1,9 @@ | |||
| import uuid | |||
| def is_valid_uuid(uuid_str: str) -> bool: | |||
| try: | |||
| uuid.UUID(uuid_str) | |||
| return True | |||
| except Exception: | |||
| return False | |||
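| Expected behaviour; note that uuid.UUID also accepts the undashed hex form: | |||
| assert is_valid_uuid("123e4567-e89b-12d3-a456-426614174000") | |||
| assert is_valid_uuid("123e4567e89b12d3a456426614174000")  # dashes are optional for uuid.UUID | |||
| assert not is_valid_uuid("not-a-uuid") | |||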
| @@ -0,0 +1,43 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from core.app.app_config.entities import VariableEntity | |||
| from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration | |||
| class WorkflowToolConfigurationUtils: | |||
| @classmethod | |||
| def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): | |||
| for configuration in configurations: | |||
| WorkflowToolParameterConfiguration.model_validate(configuration) | |||
| @classmethod | |||
| def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: | |||
| """ | |||
| get workflow graph variables | |||
| """ | |||
| nodes = graph.get("nodes", []) | |||
| start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) | |||
| if not start_node: | |||
| return [] | |||
| return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] | |||
| @classmethod | |||
| def check_is_synced( | |||
| cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] | |||
| ): | |||
| """ | |||
| check whether the tool configurations are in sync with the workflow variables | |||
| raise ValueError if they are not | |||
| """ | |||
| variable_names = [variable.variable for variable in variables] | |||
| if len(tool_configurations) != len(variables): | |||
| raise ValueError("parameter configuration mismatch, please republish the tool to update") | |||
| for parameter in tool_configurations: | |||
| if parameter.name not in variable_names: | |||
| raise ValueError("parameter configuration mismatch, please republish the tool to update") | |||
| @@ -0,0 +1,35 @@ | |||
| import logging | |||
| from pathlib import Path | |||
| from typing import Any | |||
| import yaml # type: ignore | |||
| from yaml import YAMLError | |||
| logger = logging.getLogger(__name__) | |||
| def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: | |||
| """ | |||
| Safe loading a YAML file | |||
| :param file_path: the path of the YAML file | |||
| :param ignore_error: | |||
| if True, return default_value when an error occurs and log the error at debug level | |||
| if False, raise the error | |||
| :param default_value: the value returned when errors ignored | |||
| :return: an object of the YAML content | |||
| """ | |||
| if not file_path or not Path(file_path).exists(): | |||
| if ignore_error: | |||
| return default_value | |||
| else: | |||
| raise FileNotFoundError(f"File not found: {file_path}") | |||
| with open(file_path, encoding="utf-8") as yaml_file: | |||
| try: | |||
| yaml_content = yaml.safe_load(yaml_file) | |||
| return yaml_content or default_value | |||
| except Exception as e: | |||
| if ignore_error: | |||
| logger.debug("Failed to load YAML file %s: %s", file_path, e) | |||
| return default_value | |||
| else: | |||
| raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e | |||
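| Typical call sites (paths are illustrative); passing an explicit default_value also avoids sharing the mutable {} default across calls: | |||
| provider_schema = load_yaml_file("schema/provider.yaml", ignore_error=True, default_value={})  # {} on failure | |||
| strict_schema = load_yaml_file("schema/provider.yaml", ignore_error=False)  # raises on missing or invalid file | |||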
| @@ -0,0 +1,53 @@ | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| WebsiteCrawlMessage, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def get_website_crawl( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: Mapping[str, Any], | |||
| provider_type: str, | |||
| ) -> Generator[WebsiteCrawlMessage, None, None]: | |||
| manager = PluginDatasourceManager() | |||
| return manager.get_website_crawl( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def datasource_provider_type(self) -> str: | |||
| return DatasourceProviderType.WEBSITE_CRAWL | |||
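| A consumption sketch, assuming `plugin` is a WebsiteCrawlDatasourcePlugin resolved from its provider controller and the plugin daemon is reachable; the parameter keys are illustrative and depend on the concrete crawler plugin. | |||
| for crawl_message in plugin.get_website_crawl( | |||
|     user_id="user-id", | |||
|     datasource_parameters={"url": "https://example.com", "limit": 5}, | |||
|     provider_type=plugin.datasource_provider_type(), | |||
| ): | |||
|     ...  # each WebsiteCrawlMessage carries a batch of crawl results | |||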
| @@ -0,0 +1,52 @@ | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin | |||
| class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceProviderEntityWithPlugin, | |||
| plugin_id: str, | |||
| plugin_unique_identifier: str, | |||
| tenant_id: str, | |||
| ) -> None: | |||
| super().__init__(entity, tenant_id) | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.WEBSITE_CRAWL | |||
| def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return WebsiteCrawlDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
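| Resolving a concrete datasource from the controller above, assuming `controller` is an initialized WebsiteCrawlDatasourcePluginProviderController; the datasource name is illustrative, and an unknown name raises ValueError. | |||
| assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL | |||
| try: | |||
|     crawler = controller.get_datasource("firecrawl") | |||
| except ValueError: | |||
|     crawler = None  # the provider does not bundle a datasource with that name | |||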
| @@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel): | |||
| total_segments: int | |||
| preview: list[PreviewDetail] | |||
| qa_preview: Optional[list[QAPreviewDetail]] = None | |||
| class PipelineDataset(BaseModel): | |||
| id: str | |||
| name: str | |||
| description: str | |||
| chunk_structure: str | |||
| class PipelineDocument(BaseModel): | |||
| id: str | |||
| position: int | |||
| data_source_type: str | |||
| data_source_info: Optional[dict] = None | |||
| name: str | |||
| indexing_status: str | |||
| error: Optional[str] = None | |||
| enabled: bool | |||
| class PipelineGenerateResponse(BaseModel): | |||
| batch: str | |||
| dataset: PipelineDataset | |||
| documents: list[PipelineDocument] | |||
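| A construction sketch for the response models above; the status and type strings are illustrative placeholders. | |||
| response = PipelineGenerateResponse( | |||
|     batch="20240101-abc", | |||
|     dataset=PipelineDataset(id="ds-1", name="Docs", description="", chunk_structure="text_model"), | |||
|     documents=[ | |||
|         PipelineDocument( | |||
|             id="doc-1", | |||
|             position=1, | |||
|             data_source_type="upload_file", | |||
|             name="handbook.pdf", | |||
|             indexing_status="waiting", | |||
|             enabled=True, | |||
|         ) | |||
|     ], | |||
| ) | |||
| payload = response.model_dump()  # plain dict, ready for the API layer to serialize | |||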
| @@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import ( | |||
| ) | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from extensions.ext_database import db | |||
| from libs.datetime_utils import naive_utc_now | |||
| from models.provider import ( | |||
| @@ -41,6 +40,7 @@ from models.provider import ( | |||
| ProviderType, | |||
| TenantPreferredModelProvider, | |||
| ) | |||
| from models.provider_ids import ModelProviderID | |||
| logger = logging.getLogger(__name__) | |||
| @@ -627,6 +627,7 @@ class ProviderConfiguration(BaseModel): | |||
| Get custom model credentials. | |||
| """ | |||
| # get provider model | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| @@ -1124,6 +1125,7 @@ class ProviderConfiguration(BaseModel): | |||
| """ | |||
| Get provider model setting. | |||
| """ | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| @@ -1207,6 +1209,7 @@ class ProviderConfiguration(BaseModel): | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| model_provider_id = ModelProviderID(self.provider.provider) | |||
| provider_names = [self.provider.provider] | |||
| if model_provider_id.is_langgenius(): | |||
| @@ -1340,7 +1343,7 @@ class ProviderConfiguration(BaseModel): | |||
| """ | |||
| secret_input_form_variables = [] | |||
| for credential_form_schema in credential_form_schemas: | |||
| if credential_form_schema.type == FormType.SECRET_INPUT: | |||
| if credential_form_schema.type.value == FormType.SECRET_INPUT.value: | |||
| secret_input_form_variables.append(credential_form_schema.variable) | |||
| return secret_input_form_variables | |||
| @@ -0,0 +1,15 @@ | |||
| from typing import TYPE_CHECKING, Any, cast | |||
| if TYPE_CHECKING: | |||
| from core.datasource.datasource_file_manager import DatasourceFileManager | |||
| datasource_file_manager: dict[str, Any] = {"manager": None} | |||
| class DatasourceFileParser: | |||
| @staticmethod | |||
| def get_datasource_file_manager() -> "DatasourceFileManager": | |||
| return cast("DatasourceFileManager", datasource_file_manager["manager"]) | |||
| @@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum): | |||
| REMOTE_URL = "remote_url" | |||
| LOCAL_FILE = "local_file" | |||
| TOOL_FILE = "tool_file" | |||
| DATASOURCE_FILE = "datasource_file" | |||
| @staticmethod | |||
| def value_of(value): | |||
| @@ -97,7 +97,11 @@ def to_prompt_message_content( | |||
| def download(f: File, /): | |||
| if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): | |||
| if f.transfer_method in ( | |||
| FileTransferMethod.TOOL_FILE, | |||
| FileTransferMethod.LOCAL_FILE, | |||
| FileTransferMethod.DATASOURCE_FILE, | |||
| ): | |||
| return _download_file_content(f._storage_key) | |||
| elif f.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| response = ssrf_proxy.get(f.remote_url, follow_redirects=True) | |||
| @@ -115,11 +115,10 @@ class File(BaseModel): | |||
| if self.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| return helpers.get_signed_file_url(upload_file_id=self.related_id) | |||
| elif self.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| elif self.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE): | |||
| assert self.related_id is not None | |||
| assert self.extension is not None | |||
| return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) | |||
| def to_plugin_parameter(self) -> dict[str, Any]: | |||
| return { | |||
| "dify_model_identity": FILE_MODEL_IDENTITY, | |||
| @@ -12,8 +12,8 @@ def obfuscated_token(token: str): | |||
| def encrypt_token(tenant_id: str, token: str): | |||
| from extensions.ext_database import db | |||
| from models.account import Tenant | |||
| from models.engine import db | |||
| if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): | |||
| raise ValueError(f"Tenant with id {tenant_id} not found") | |||
| @@ -0,0 +1,42 @@ | |||
| import logging | |||
| import re | |||
| from collections.abc import Sequence | |||
| from typing import Any | |||
| from core.tools.entities.tool_entities import CredentialType | |||
| logger = logging.getLogger(__name__) | |||
| def generate_provider_name( | |||
| providers: Sequence[Any], credential_type: CredentialType, fallback_context: str = "provider" | |||
| ) -> str: | |||
| try: | |||
| return generate_incremental_name( | |||
| [provider.name for provider in providers], | |||
| f"{credential_type.get_name()}", | |||
| ) | |||
| except Exception as e: | |||
| logger.warning("Error generating next provider name for %r: %r", fallback_context, e) | |||
| return f"{credential_type.get_name()} 1" | |||
| def generate_incremental_name( | |||
| names: Sequence[str], | |||
| default_pattern: str, | |||
| ) -> str: | |||
| pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" | |||
| numbers = [] | |||
| for name in names: | |||
| if not name: | |||
| continue | |||
| match = re.match(pattern, name.strip()) | |||
| if match: | |||
| numbers.append(int(match.group(1))) | |||
| if not numbers: | |||
| return f"{default_pattern} 1" | |||
| max_number = max(numbers) | |||
| return f"{default_pattern} {max_number + 1}" | |||
| @@ -359,6 +359,7 @@ class IndexingRunner: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "credential_id": data_source_info["credential_id"], | |||
| "notion_workspace_id": data_source_info["notion_workspace_id"], | |||
| "notion_obj_id": data_source_info["notion_page_id"], | |||
| "notion_page_type": data_source_info["type"], | |||
| @@ -28,9 +28,10 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.ops.utils import measure_time | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.event import AgentLogEvent | |||
| from core.workflow.node_events import AgentLogEvent | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models import App, Message, WorkflowNodeExecutionModel, db | |||
| from models import App, Message, WorkflowNodeExecutionModel | |||
| logger = logging.getLogger(__name__) | |||
| @@ -2,6 +2,7 @@ from collections.abc import Sequence | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.file import file_manager | |||
| @@ -32,7 +33,12 @@ class TokenBufferMemory: | |||
| self.model_instance = model_instance | |||
| def _build_prompt_message_with_files( | |||
| self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool | |||
| self, | |||
| message_files: Sequence[MessageFile], | |||
| text_content: str, | |||
| message: Message, | |||
| app_record, | |||
| is_user_message: bool, | |||
| ) -> PromptMessage: | |||
| """ | |||
| Build prompt message with files. | |||
| @@ -98,80 +104,80 @@ class TokenBufferMemory: | |||
| :param max_token_limit: max token limit | |||
| :param message_limit: message limit | |||
| """ | |||
| app_record = self.conversation.app | |||
| # fetch limited messages, and return reversed | |||
| stmt = ( | |||
| select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) | |||
| ) | |||
| with Session(db.engine) as session: | |||
| app_record = self.conversation.app | |||
| # fetch limited messages, and return reversed | |||
| stmt = ( | |||
| select(Message) | |||
| .where(Message.conversation_id == self.conversation.id) | |||
| .order_by(Message.created_at.desc()) | |||
| ) | |||
| if message_limit and message_limit > 0: | |||
| message_limit = min(message_limit, 500) | |||
| else: | |||
| message_limit = 500 | |||
| if message_limit and message_limit > 0: | |||
| message_limit = min(message_limit, 500) | |||
| else: | |||
| message_limit = 500 | |||
| stmt = stmt.limit(message_limit) | |||
| stmt = stmt.limit(message_limit) | |||
| messages = db.session.scalars(stmt).all() | |||
| messages = session.scalars(stmt).all() | |||
| # instead of all messages from the conversation, we only need to extract messages | |||
| # that belong to the thread of last message | |||
| thread_messages = extract_thread_messages(messages) | |||
| # instead of all messages from the conversation, we only need to extract messages | |||
| # that belong to the thread of last message | |||
| thread_messages = extract_thread_messages(messages) | |||
| # for newly created message, its answer is temporarily empty, we don't need to add it to memory | |||
| if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: | |||
| thread_messages.pop(0) | |||
| # for newly created message, its answer is temporarily empty, we don't need to add it to memory | |||
| if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: | |||
| thread_messages.pop(0) | |||
| messages = list(reversed(thread_messages)) | |||
| messages = list(reversed(thread_messages)) | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| # Process user message with files | |||
| user_files = ( | |||
| db.session.query(MessageFile) | |||
| .where( | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| # Process user message with files | |||
| user_file_query = select(MessageFile).where( | |||
| MessageFile.message_id == message.id, | |||
| (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), | |||
| ) | |||
| .all() | |||
| ) | |||
| if user_files: | |||
| user_prompt_message = self._build_prompt_message_with_files( | |||
| message_files=user_files, | |||
| text_content=message.query, | |||
| message=message, | |||
| app_record=app_record, | |||
| is_user_message=True, | |||
| user_files = session.scalars(user_file_query).all() | |||
| if user_files: | |||
| user_prompt_message = self._build_prompt_message_with_files( | |||
| message_files=user_files, | |||
| text_content=message.query, | |||
| message=message, | |||
| app_record=app_record, | |||
| is_user_message=True, | |||
| ) | |||
| prompt_messages.append(user_prompt_message) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=message.query)) | |||
| # Process assistant message with files | |||
| assistant_file_query = select(MessageFile).where( | |||
| MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant" | |||
| ) | |||
| prompt_messages.append(user_prompt_message) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=message.query)) | |||
| # Process assistant message with files | |||
| assistant_files = ( | |||
| db.session.query(MessageFile) | |||
| .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") | |||
| .all() | |||
| ) | |||
| if assistant_files: | |||
| assistant_prompt_message = self._build_prompt_message_with_files( | |||
| message_files=assistant_files, | |||
| text_content=message.answer, | |||
| message=message, | |||
| app_record=app_record, | |||
| is_user_message=False, | |||
| ) | |||
| prompt_messages.append(assistant_prompt_message) | |||
| else: | |||
| prompt_messages.append(AssistantPromptMessage(content=message.answer)) | |||
| if not prompt_messages: | |||
| return [] | |||
| # prune the chat message if it exceeds the max token limit | |||
| curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) | |||
| assistant_files = session.scalars(assistant_file_query).all() | |||
| if assistant_files: | |||
| assistant_prompt_message = self._build_prompt_message_with_files( | |||
| message_files=assistant_files, | |||
| text_content=message.answer, | |||
| message=message, | |||
| app_record=app_record, | |||
| is_user_message=False, | |||
| ) | |||
| prompt_messages.append(assistant_prompt_message) | |||
| else: | |||
| prompt_messages.append(AssistantPromptMessage(content=message.answer)) | |||
| if not prompt_messages: | |||
| return [] | |||
| # prune the chat message if it exceeds the max token limit | |||
| curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) | |||
| if curr_message_tokens > max_token_limit: | |||
| while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: | |||
| @@ -24,8 +24,7 @@ from core.model_runtime.errors.invoke import ( | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity | |||
| from core.plugin.impl.model import PluginModelClient | |||
| from core.plugin.entities.plugin_daemon import PluginModelProviderEntity | |||
| class AIModel(BaseModel): | |||
| @@ -53,6 +52,8 @@ class AIModel(BaseModel): | |||
| :return: Invoke error mapping | |||
| """ | |||
| from core.plugin.entities.plugin_daemon import PluginDaemonInnerError | |||
| return { | |||
| InvokeConnectionError: [InvokeConnectionError], | |||
| InvokeServerUnavailableError: [InvokeServerUnavailableError], | |||
| @@ -140,6 +141,8 @@ class AIModel(BaseModel): | |||
| :param credentials: model credentials | |||
| :return: model schema | |||
| """ | |||
| from core.plugin.impl.model import PluginModelClient | |||
| plugin_model_manager = PluginModelClient() | |||
| cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" | |||
| # sort credentials | |||
| @@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import ( | |||
| PriceType, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.plugin.impl.model import PluginModelClient | |||
| logger = logging.getLogger(__name__) | |||
| @@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel): | |||
| result: Union[LLMResult, Generator[LLMResultChunk, None, None]] | |||
| try: | |||
| from core.plugin.impl.model import PluginModelClient | |||
| plugin_model_manager = PluginModelClient() | |||
| result = plugin_model_manager.invoke_llm( | |||
| tenant_id=self.tenant_id, | |||
| @@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel): | |||
| :return: | |||
| """ | |||
| if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: | |||
| from core.plugin.impl.model import PluginModelClient | |||
| plugin_model_manager = PluginModelClient() | |||
| return plugin_model_manager.get_llm_num_tokens( | |||
| tenant_id=self.tenant_id, | |||