| @@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| from core.helper import encrypter | |||
| from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource | |||
| from core.plugin.entities.plugin import PluginInstallationSource | |||
| from core.plugin.impl.plugin import PluginInstaller | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| @@ -35,7 +35,7 @@ from models.dataset import Document as DatasetDocument | |||
| from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | |||
| from models.oauth import DatasourceOauthParamConfig, DatasourceProvider | |||
| from models.provider import Provider, ProviderModel | |||
| from models.provider_ids import ToolProviderID | |||
| from models.provider_ids import DatasourceProviderID, ToolProviderID | |||
| from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding | |||
| from models.tools import ToolOAuthSystemClient | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| @@ -11,10 +11,10 @@ from controllers.console.wraps import ( | |||
| setup_required, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.plugin.entities.plugin import DatasourceProviderID | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from libs.helper import StrLen | |||
| from libs.login import login_required | |||
| from models.provider_ids import DatasourceProviderID | |||
| from services.datasource_provider_service import DatasourceProviderService | |||
| from services.plugin.oauth_service import OAuthProxyService | |||
| @@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.variables.types import SegmentType | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from extensions.ext_database import db | |||
| from factories.file_factory import build_from_mapping, build_from_mappings | |||
| from factories.variable_factory import build_segment_with_type | |||
| from libs.login import current_user, login_required | |||
| from models import db | |||
| from models.account import Account | |||
| from models.dataset import Pipeline | |||
| from models.workflow import WorkflowDraftVariable | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| @@ -131,7 +132,7 @@ def _api_prerequisite(f): | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def wrapper(*args, **kwargs): | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| return f(*args, **kwargs) | |||
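Reviewer note: the hunks in this file repeatedly replace `if not current_user.is_editor` with a combined type-and-role guard, presumably because `current_user` can resolve to something other than an `Account` (e.g. an `EndUser`) that has no `is_editor` attribute. A minimal self-contained sketch of the pattern, with stand-ins for `models.account.Account` and werkzeug's `Forbidden`:

    class Forbidden(Exception):
        pass

    class Account:
        def __init__(self, is_editor: bool = False):
            self.is_editor = is_editor

    class EndUser:
        # no is_editor attribute, which is why the isinstance() check must come first
        pass

    def ensure_editor(current_user) -> None:
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()

    ensure_editor(Account(is_editor=True))   # allowed
    try:
        ensure_editor(EndUser())             # rejected before is_editor is ever read
    except Forbidden:
        print("forbidden")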
| @@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource): | |||
| Get draft rag pipeline's workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| # fetch draft workflow by app_model | |||
| @@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource): | |||
| Sync draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| content_type = request.headers.get("Content-Type", "") | |||
| @@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): | |||
| Run draft workflow iteration node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): | |||
| Run draft workflow loop node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource): | |||
| Run draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource): | |||
| Run published workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource): | |||
| # | |||
| # return result | |||
| # | |||
| class RagPipelinePublishedDatasourceNodeRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): | |||
| Run rag pipeline datasource | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): | |||
| Run rag pipeline datasource | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| @@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource): | |||
| Run draft workflow node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| @@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource): | |||
| Stop workflow task | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) | |||
| @@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource): | |||
| Get published pipeline | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not pipeline.is_published: | |||
| return None | |||
| @@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource): | |||
| Publish workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| rag_pipeline_service = RagPipelineService() | |||
| @@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource): | |||
| Get default block config | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| # Get default block configs | |||
| @@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource): | |||
| Get default block config | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| @@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource): | |||
| """ | |||
| Get published workflows | |||
| """ | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| @@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource): | |||
| Update workflow attributes | |||
| """ | |||
| # Check permission | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| @@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource): | |||
| Get second step parameters of rag pipeline | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("node_id", type=str, required=True, location="args") | |||
| @@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource): | |||
| Get first step parameters of rag pipeline | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("node_id", type=str, required=True, location="args") | |||
| @@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource): | |||
| Get first step parameters of rag pipeline | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("node_id", type=str, required=True, location="args") | |||
| @@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource): | |||
| Get second step parameters of rag pipeline | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("node_id", type=str, required=True, location="args") | |||
| @@ -926,8 +912,11 @@ class DatasourceListApi(Resource): | |||
| @account_initialization_required | |||
| def get(self): | |||
| user = current_user | |||
| if not isinstance(user, Account): | |||
| raise Forbidden() | |||
| tenant_id = user.current_tenant_id | |||
| if not tenant_id: | |||
| raise Forbidden() | |||
| return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) | |||
| @@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource): | |||
| """ | |||
| Set datasource variables | |||
| """ | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| if not isinstance(current_user, Account) or not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| @@ -5,6 +5,7 @@ from typing import Optional | |||
| from controllers.console.datasets.error import PipelineNotFoundError | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.dataset import Pipeline | |||
| @@ -17,6 +18,9 @@ def get_rag_pipeline( | |||
| if not kwargs.get("pipeline_id"): | |||
| raise ValueError("missing pipeline_id in path parameters") | |||
| if not isinstance(current_user, Account): | |||
| raise ValueError("current_user is not an account") | |||
| pipeline_id = kwargs.get("pipeline_id") | |||
| pipeline_id = str(pipeline_id) | |||
| @@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource): | |||
| return [], 200 | |||
| api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") | |||
| api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") | |||
| @@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| # validate args | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| if not current_user: | |||
| raise ValueError("current_user is required") | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| @@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner): | |||
| tenant_id=tenant_id, | |||
| dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], | |||
| retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, | |||
| return_resource=app_config.additional_features.show_retrieve_source, | |||
| return_resource=( | |||
| app_config.additional_features.show_retrieve_source if app_config.additional_features else False | |||
| ), | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| hit_callback=hit_callback, | |||
| user_id=user_id, | |||
| @@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| # always enable retriever resource in debugger mode | |||
| app_config.additional_features.show_retrieve_source = True | |||
| app_config.additional_features.show_retrieve_source = True # type: ignore | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| # init application generate entity | |||
| @@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner): | |||
| config=app_config.dataset, | |||
| query=query, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| show_retrieve_source=app_config.additional_features.show_retrieve_source, | |||
| show_retrieve_source=( | |||
| app_config.additional_features.show_retrieve_source if app_config.additional_features else False | |||
| ), | |||
| hit_callback=hit_callback, | |||
| memory=memory, | |||
| message_id=message.id, | |||
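Both runner hunks above add the same None guard, since `additional_features` appears to be optional on the app config. A small standalone sketch of the fallback, using dataclass stand-ins for the real config objects:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AdditionalFeatures:
        show_retrieve_source: bool = False

    @dataclass
    class AppConfig:
        additional_features: Optional[AdditionalFeatures] = None

    def resolve_show_retrieve_source(app_config: AppConfig) -> bool:
        # fall back to False instead of dereferencing a missing additional_features
        return app_config.additional_features.show_retrieve_source if app_config.additional_features else False

    print(resolve_show_retrieve_source(AppConfig()))                          # False
    print(resolve_show_retrieve_source(AppConfig(AdditionalFeatures(True))))  # True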
| @@ -36,8 +36,8 @@ from core.app.entities.task_entities import ( | |||
| WorkflowStartStreamResponse, | |||
| ) | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.variables.segments import ArrayFileSegment, FileSegment, Segment | |||
| from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution | |||
| @@ -1,8 +1,7 @@ | |||
| import logging | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| import time | |||
| from typing import Optional, cast | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| @@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import ( | |||
| RagPipelineGenerateEntity, | |||
| ) | |||
| from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| @@ -22,7 +23,7 @@ from extensions.ext_database import db | |||
| from models.dataset import Document, Pipeline | |||
| from models.enums import UserFrom | |||
| from models.model import EndUser | |||
| from models.workflow import Workflow, WorkflowType | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| db.session.close() | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # if only single iteration run is requested | |||
| if self.application_generate_entity.single_iteration_run: | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| graph_runtime_state = GraphRuntimeState( | |||
| variable_pool=VariablePool.empty(), | |||
| start_at=time.time(), | |||
| ) | |||
| # if only single loop run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_loop_run.inputs, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| @@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| datasource_info=self.application_generate_entity.datasource_info, | |||
| invoke_from=self.application_generate_entity.invoke_from.value, | |||
| ) | |||
| rag_pipeline_variables = [] | |||
| if workflow.rag_pipeline_variables: | |||
| for v in workflow.rag_pipeline_variables: | |||
| @@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| conversation_variables=[], | |||
| rag_pipeline_variables=rag_pipeline_variables, | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| # init graph | |||
| graph = self._init_rag_pipeline_graph( | |||
| graph_config=workflow.graph_dict, | |||
| graph_runtime_state=graph_runtime_state, | |||
| start_node_id=self.application_generate_entity.start_node_id, | |||
| workflow=workflow, | |||
| ) | |||
| # RUN WORKFLOW | |||
| @@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| @@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| thread_pool_id=self.workflow_thread_pool_id, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| generator = workflow_entry.run(callbacks=workflow_callbacks) | |||
| generator = workflow_entry.run() | |||
| for event in generator: | |||
| self._update_document_status( | |||
| @@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| # return workflow | |||
| return workflow | |||
| def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph: | |||
| def _init_rag_pipeline_graph( | |||
| self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None | |||
| ) -> Graph: | |||
| """ | |||
| Init pipeline graph | |||
| """ | |||
| graph_config = workflow.graph_dict | |||
| if "nodes" not in graph_config or "edges" not in graph_config: | |||
| raise ValueError("nodes or edges not found in workflow graph") | |||
| @@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner): | |||
| graph_config["nodes"] = real_run_nodes | |||
| graph_config["edges"] = real_edges | |||
| # init graph | |||
| graph = Graph.init(graph_config=graph_config) | |||
| # Create required parameters for Graph.init | |||
| graph_init_params = GraphInitParams( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=self._app_id, | |||
| workflow_id=workflow.id, | |||
| graph_config=graph_config, | |||
| user_id="", | |||
| user_from=UserFrom.ACCOUNT.value, | |||
| invoke_from=InvokeFrom.SERVICE_API.value, | |||
| call_depth=0, | |||
| ) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=graph_init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id) | |||
| if not graph: | |||
| raise ValueError("graph not found in workflow") | |||
| @@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import ( | |||
| OnlineDriveDownloadFileRequest, | |||
| WebsiteCrawlMessage, | |||
| ) | |||
| from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID | |||
| from core.plugin.entities.plugin_daemon import ( | |||
| PluginBasicBooleanResponse, | |||
| PluginDatasourceProviderEntity, | |||
| ) | |||
| from core.plugin.impl.base import BasePluginClient | |||
| from core.schemas.resolver import resolve_dify_schema_refs | |||
| from models.provider_ids import DatasourceProviderID, GenericProviderID | |||
| from services.tools.tools_transform_service import ToolTransformService | |||
| @@ -2,7 +2,7 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Mapping | |||
| from typing import Any, TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, Any, Optional | |||
| from configs import dify_config | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| @@ -2,4 +2,4 @@ | |||
| from .resolver import resolve_dify_schema_refs | |||
| __all__ = ["resolve_dify_schema_refs"] | |||
| __all__ = ["resolve_dify_schema_refs"] | |||
| @@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional | |||
| class SchemaRegistry: | |||
| """Schema registry manages JSON schemas with version support""" | |||
| _default_instance: ClassVar[Optional["SchemaRegistry"]] = None | |||
| _lock: ClassVar[threading.Lock] = threading.Lock() | |||
| @@ -25,41 +25,41 @@ class SchemaRegistry: | |||
| if cls._default_instance is None: | |||
| current_dir = Path(__file__).parent | |||
| schema_dir = current_dir / "builtin" / "schemas" | |||
| registry = cls(str(schema_dir)) | |||
| registry.load_all_versions() | |||
| cls._default_instance = registry | |||
| return cls._default_instance | |||
| def load_all_versions(self) -> None: | |||
| """Scans the schema directory and loads all versions""" | |||
| if not self.base_dir.exists(): | |||
| return | |||
| for entry in self.base_dir.iterdir(): | |||
| if not entry.is_dir(): | |||
| continue | |||
| version = entry.name | |||
| if not version.startswith("v"): | |||
| continue | |||
| self._load_version_dir(version, entry) | |||
| def _load_version_dir(self, version: str, version_dir: Path) -> None: | |||
| """Loads all schemas in a version directory""" | |||
| if not version_dir.exists(): | |||
| return | |||
| if version not in self.versions: | |||
| self.versions[version] = {} | |||
| for entry in version_dir.iterdir(): | |||
| if entry.suffix != ".json": | |||
| continue | |||
| schema_name = entry.stem | |||
| self._load_schema(version, schema_name, entry) | |||
| @@ -68,10 +68,10 @@ class SchemaRegistry: | |||
| try: | |||
| with open(schema_path, encoding="utf-8") as f: | |||
| schema = json.load(f) | |||
| # Store the schema | |||
| self.versions[version][schema_name] = schema | |||
| # Extract and store metadata | |||
| uri = f"https://dify.ai/schemas/{version}/{schema_name}.json" | |||
| metadata = { | |||
| @@ -81,26 +81,26 @@ class SchemaRegistry: | |||
| "deprecated": schema.get("deprecated", False), | |||
| } | |||
| self.metadata[uri] = metadata | |||
| except (OSError, json.JSONDecodeError) as e: | |||
| print(f"Warning: failed to load schema {version}/{schema_name}: {e}") | |||
| def get_schema(self, uri: str) -> Optional[Any]: | |||
| """Retrieves a schema by URI with version support""" | |||
| version, schema_name = self._parse_uri(uri) | |||
| if not version or not schema_name: | |||
| return None | |||
| version_schemas = self.versions.get(version) | |||
| if not version_schemas: | |||
| return None | |||
| return version_schemas.get(schema_name) | |||
| def _parse_uri(self, uri: str) -> tuple[str, str]: | |||
| """Parses a schema URI to extract version and schema name""" | |||
| from core.schemas.resolver import parse_dify_schema_uri | |||
| return parse_dify_schema_uri(uri) | |||
| def list_versions(self) -> list[str]: | |||
| @@ -112,19 +112,15 @@ class SchemaRegistry: | |||
| version_schemas = self.versions.get(version) | |||
| if not version_schemas: | |||
| return [] | |||
| return sorted(version_schemas.keys()) | |||
| def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]: | |||
| """Returns all schemas for a version in the API format""" | |||
| version_schemas = self.versions.get(version, {}) | |||
| result = [] | |||
| for schema_name, schema in version_schemas.items(): | |||
| result.append({ | |||
| "name": schema_name, | |||
| "label": schema.get("title", schema_name), | |||
| "schema": schema | |||
| }) | |||
| return result | |||
| result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema}) | |||
| return result | |||
| @@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$ | |||
| class SchemaResolutionError(Exception): | |||
| """Base exception for schema resolution errors""" | |||
| pass | |||
| class CircularReferenceError(SchemaResolutionError): | |||
| """Raised when a circular reference is detected""" | |||
| def __init__(self, ref_uri: str, ref_path: list[str]): | |||
| self.ref_uri = ref_uri | |||
| self.ref_path = ref_path | |||
| @@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError): | |||
| class MaxDepthExceededError(SchemaResolutionError): | |||
| """Raised when maximum resolution depth is exceeded""" | |||
| def __init__(self, max_depth: int): | |||
| self.max_depth = max_depth | |||
| super().__init__(f"Maximum resolution depth ({max_depth}) exceeded") | |||
| @@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError): | |||
| class SchemaNotFoundError(SchemaResolutionError): | |||
| """Raised when a referenced schema cannot be found""" | |||
| def __init__(self, ref_uri: str): | |||
| self.ref_uri = ref_uri | |||
| super().__init__(f"Schema not found: {ref_uri}") | |||
| @@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError): | |||
| @dataclass | |||
| class QueueItem: | |||
| """Represents an item in the BFS queue""" | |||
| current: Any | |||
| parent: Optional[Any] | |||
| key: Optional[Union[str, int]] | |||
| @@ -56,39 +61,39 @@ class QueueItem: | |||
| class SchemaResolver: | |||
| """Resolver for Dify schema references with caching and optimizations""" | |||
| _cache: dict[str, SchemaDict] = {} | |||
| _cache_lock = threading.Lock() | |||
| def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10): | |||
| """ | |||
| Initialize the schema resolver | |||
| Args: | |||
| registry: Schema registry to use (defaults to default registry) | |||
| max_depth: Maximum depth for reference resolution | |||
| """ | |||
| self.registry = registry or SchemaRegistry.default_registry() | |||
| self.max_depth = max_depth | |||
| @classmethod | |||
| def clear_cache(cls) -> None: | |||
| """Clear the global schema cache""" | |||
| with cls._cache_lock: | |||
| cls._cache.clear() | |||
| def resolve(self, schema: SchemaType) -> SchemaType: | |||
| """ | |||
| Resolve all $ref references in the schema | |||
| Performance optimization: quickly checks for $ref presence before processing. | |||
| Args: | |||
| schema: Schema to resolve | |||
| Returns: | |||
| Resolved schema with all references expanded | |||
| Raises: | |||
| CircularReferenceError: If circular reference detected | |||
| MaxDepthExceededError: If max depth exceeded | |||
| @@ -96,44 +101,39 @@ class SchemaResolver: | |||
| """ | |||
| if not isinstance(schema, (dict, list)): | |||
| return schema | |||
| # Fast path: if no Dify refs found, return original schema unchanged | |||
| # This avoids expensive deepcopy and BFS traversal for schemas without refs | |||
| if not _has_dify_refs(schema): | |||
| return schema | |||
| # Slow path: schema contains refs, perform full resolution | |||
| import copy | |||
| result = copy.deepcopy(schema) | |||
| # Initialize BFS queue | |||
| queue = deque([QueueItem( | |||
| current=result, | |||
| parent=None, | |||
| key=None, | |||
| depth=0, | |||
| ref_path=set() | |||
| )]) | |||
| queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())]) | |||
| while queue: | |||
| item = queue.popleft() | |||
| # Process the current item | |||
| self._process_queue_item(queue, item) | |||
| return result | |||
| def _process_queue_item(self, queue: deque, item: QueueItem) -> None: | |||
| """Process a single queue item""" | |||
| if isinstance(item.current, dict): | |||
| self._process_dict(queue, item) | |||
| elif isinstance(item.current, list): | |||
| self._process_list(queue, item) | |||
| def _process_dict(self, queue: deque, item: QueueItem) -> None: | |||
| """Process a dictionary item""" | |||
| ref_uri = item.current.get("$ref") | |||
| if ref_uri and _is_dify_schema_ref(ref_uri): | |||
| # Handle $ref resolution | |||
| self._resolve_ref(queue, item, ref_uri) | |||
| @@ -144,14 +144,10 @@ class SchemaResolver: | |||
| next_depth = item.depth + 1 | |||
| if next_depth >= self.max_depth: | |||
| raise MaxDepthExceededError(self.max_depth) | |||
| queue.append(QueueItem( | |||
| current=value, | |||
| parent=item.current, | |||
| key=key, | |||
| depth=next_depth, | |||
| ref_path=item.ref_path | |||
| )) | |||
| queue.append( | |||
| QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path) | |||
| ) | |||
| def _process_list(self, queue: deque, item: QueueItem) -> None: | |||
| """Process a list item""" | |||
| for idx, value in enumerate(item.current): | |||
| @@ -159,14 +155,10 @@ class SchemaResolver: | |||
| next_depth = item.depth + 1 | |||
| if next_depth >= self.max_depth: | |||
| raise MaxDepthExceededError(self.max_depth) | |||
| queue.append(QueueItem( | |||
| current=value, | |||
| parent=item.current, | |||
| key=idx, | |||
| depth=next_depth, | |||
| ref_path=item.ref_path | |||
| )) | |||
| queue.append( | |||
| QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path) | |||
| ) | |||
| def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None: | |||
| """Resolve a $ref reference""" | |||
| # Check for circular reference | |||
| @@ -175,82 +167,78 @@ class SchemaResolver: | |||
| item.current["$circular_ref"] = True | |||
| logger.warning("Circular reference detected: %s", ref_uri) | |||
| return | |||
| # Get resolved schema (from cache or registry) | |||
| resolved_schema = self._get_resolved_schema(ref_uri) | |||
| if not resolved_schema: | |||
| logger.warning("Schema not found: %s", ref_uri) | |||
| return | |||
| # Update ref path | |||
| new_ref_path = item.ref_path | {ref_uri} | |||
| # Replace the reference with resolved schema | |||
| next_depth = item.depth + 1 | |||
| if next_depth >= self.max_depth: | |||
| raise MaxDepthExceededError(self.max_depth) | |||
| if item.parent is None: | |||
| # Root level replacement | |||
| item.current.clear() | |||
| item.current.update(resolved_schema) | |||
| queue.append(QueueItem( | |||
| current=item.current, | |||
| parent=None, | |||
| key=None, | |||
| depth=next_depth, | |||
| ref_path=new_ref_path | |||
| )) | |||
| queue.append( | |||
| QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path) | |||
| ) | |||
| else: | |||
| # Update parent container | |||
| item.parent[item.key] = resolved_schema.copy() | |||
| queue.append(QueueItem( | |||
| current=item.parent[item.key], | |||
| parent=item.parent, | |||
| key=item.key, | |||
| depth=next_depth, | |||
| ref_path=new_ref_path | |||
| )) | |||
| queue.append( | |||
| QueueItem( | |||
| current=item.parent[item.key], | |||
| parent=item.parent, | |||
| key=item.key, | |||
| depth=next_depth, | |||
| ref_path=new_ref_path, | |||
| ) | |||
| ) | |||
| def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]: | |||
| """Get resolved schema from cache or registry""" | |||
| # Check cache first | |||
| with self._cache_lock: | |||
| if ref_uri in self._cache: | |||
| return self._cache[ref_uri].copy() | |||
| # Fetch from registry | |||
| schema = self.registry.get_schema(ref_uri) | |||
| if not schema: | |||
| return None | |||
| # Clean and cache | |||
| cleaned = _remove_metadata_fields(schema) | |||
| with self._cache_lock: | |||
| self._cache[ref_uri] = cleaned | |||
| return cleaned.copy() | |||
| def resolve_dify_schema_refs( | |||
| schema: SchemaType, | |||
| registry: Optional[SchemaRegistry] = None, | |||
| max_depth: int = 30 | |||
| schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30 | |||
| ) -> SchemaType: | |||
| """ | |||
| Resolve $ref references in Dify schema to actual schema content | |||
| This is a convenience function that creates a resolver and resolves the schema. | |||
| Performance optimization: quickly checks for $ref presence before processing. | |||
| Args: | |||
| schema: Schema object that may contain $ref references | |||
| registry: Optional schema registry, defaults to default registry | |||
| max_depth: Maximum depth to prevent infinite loops (default: 30) | |||
| Returns: | |||
| Schema with all $ref references resolved to actual content | |||
| Raises: | |||
| CircularReferenceError: If circular reference detected | |||
| MaxDepthExceededError: If maximum depth exceeded | |||
| @@ -260,7 +248,7 @@ def resolve_dify_schema_refs( | |||
| # This avoids expensive deepcopy and BFS traversal for schemas without refs | |||
| if not _has_dify_refs(schema): | |||
| return schema | |||
| # Slow path: schema contains refs, perform full resolution | |||
| resolver = SchemaResolver(registry, max_depth) | |||
| return resolver.resolve(schema) | |||
| @@ -269,36 +257,36 @@ def resolve_dify_schema_refs( | |||
| def _remove_metadata_fields(schema: dict) -> dict: | |||
| """ | |||
| Remove metadata fields from schema that shouldn't be included in resolved output | |||
| Args: | |||
| schema: Schema dictionary | |||
| Returns: | |||
| Cleaned schema without metadata fields | |||
| """ | |||
| # Create a copy and remove metadata fields | |||
| cleaned = schema.copy() | |||
| metadata_fields = ["$id", "$schema", "version"] | |||
| for field in metadata_fields: | |||
| cleaned.pop(field, None) | |||
| return cleaned | |||
| def _is_dify_schema_ref(ref_uri: Any) -> bool: | |||
| """ | |||
| Check if the reference URI is a Dify schema reference | |||
| Args: | |||
| ref_uri: URI to check | |||
| Returns: | |||
| True if it's a Dify schema reference | |||
| """ | |||
| if not isinstance(ref_uri, str): | |||
| return False | |||
| # Use pre-compiled pattern for better performance | |||
| return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri)) | |||
| @@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool: | |||
| def _has_dify_refs_recursive(schema: SchemaType) -> bool: | |||
| """ | |||
| Recursively check if a schema contains any Dify $ref references | |||
| This is the fallback method when string-based detection is not possible. | |||
| Args: | |||
| schema: Schema to check for references | |||
| Returns: | |||
| True if any Dify $ref is found, False otherwise | |||
| """ | |||
| @@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool: | |||
| ref_uri = schema.get("$ref") | |||
| if ref_uri and _is_dify_schema_ref(ref_uri): | |||
| return True | |||
| # Check nested values | |||
| for value in schema.values(): | |||
| if _has_dify_refs_recursive(value): | |||
| return True | |||
| elif isinstance(schema, list): | |||
| # Check each item in the list | |||
| for item in schema: | |||
| if _has_dify_refs_recursive(item): | |||
| return True | |||
| # Primitive types don't contain refs | |||
| return False | |||
| @@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool: | |||
| def _has_dify_refs_hybrid(schema: SchemaType) -> bool: | |||
| """ | |||
| Hybrid detection: fast string scan followed by precise recursive check | |||
| Performance optimization using two-phase detection: | |||
| 1. Fast string scan to quickly eliminate schemas without $ref | |||
| 2. Precise recursive validation only for potential candidates | |||
| Args: | |||
| schema: Schema to check for references | |||
| Returns: | |||
| True if any Dify $ref is found, False otherwise | |||
| """ | |||
| # Phase 1: Fast string-based pre-filtering | |||
| try: | |||
| import json | |||
| schema_str = json.dumps(schema, separators=(',', ':')) | |||
| schema_str = json.dumps(schema, separators=(",", ":")) | |||
| # Quick elimination: no $ref at all | |||
| if '"$ref"' not in schema_str: | |||
| return False | |||
| # Quick elimination: no Dify schema URLs | |||
| if 'https://dify.ai/schemas/' not in schema_str: | |||
| if "https://dify.ai/schemas/" not in schema_str: | |||
| return False | |||
| except (TypeError, ValueError, OverflowError): | |||
| # JSON serialization failed (e.g., circular references, non-serializable objects) | |||
| # Fall back to recursive detection | |||
| logger.debug("JSON serialization failed for schema, using recursive detection") | |||
| return _has_dify_refs_recursive(schema) | |||
| # Phase 2: Precise recursive validation | |||
| # Only executed for schemas that passed string pre-filtering | |||
| return _has_dify_refs_recursive(schema) | |||
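For readers skimming the resolver changes: the hybrid check keeps a cheap `json.dumps` pre-filter in front of the recursive walk, so schemas without Dify refs never pay for a deep copy. A standalone re-implementation of the same two-phase idea (not the module itself), runnable on its own:

    import json
    import re

    _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")

    def has_dify_refs(schema) -> bool:
        try:
            # Phase 1: fast string scan rules out schemas with no candidate refs.
            text = json.dumps(schema, separators=(",", ":"))
            if '"$ref"' not in text or "https://dify.ai/schemas/" not in text:
                return False
        except (TypeError, ValueError, OverflowError):
            pass  # non-serializable input: fall through to the recursive walk

        # Phase 2: precise recursive walk confirms at least one Dify-style $ref.
        def walk(node) -> bool:
            if isinstance(node, dict):
                ref = node.get("$ref")
                if isinstance(ref, str) and _DIFY_SCHEMA_PATTERN.match(ref):
                    return True
                return any(walk(v) for v in node.values())
            if isinstance(node, list):
                return any(walk(v) for v in node)
            return False

        return walk(schema)

    print(has_dify_refs({"a": {"$ref": "https://dify.ai/schemas/v1/file.json"}}))  # True
    print(has_dify_refs({"a": {"$ref": "#/definitions/local"}}))                   # False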
| @@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool: | |||
| def _has_dify_refs(schema: SchemaType) -> bool: | |||
| """ | |||
| Check if a schema contains any Dify $ref references | |||
| Uses hybrid detection for optimal performance: | |||
| - Fast string scan for quick elimination | |||
| - Fast string scan for quick elimination | |||
| - Precise recursive check for validation | |||
| Args: | |||
| schema: Schema to check for references | |||
| Returns: | |||
| True if any Dify $ref is found, False otherwise | |||
| """ | |||
| @@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool: | |||
| def parse_dify_schema_uri(uri: str) -> tuple[str, str]: | |||
| """ | |||
| Parse a Dify schema URI to extract version and schema name | |||
| Args: | |||
| uri: Schema URI to parse | |||
| Returns: | |||
| Tuple of (version, schema_name) or ("", "") if invalid | |||
| """ | |||
| match = _DIFY_SCHEMA_PATTERN.match(uri) | |||
| if not match: | |||
| return "", "" | |||
| return match.group(1), match.group(2) | |||
| return match.group(1), match.group(2) | |||
| @@ -13,10 +13,10 @@ class SchemaManager: | |||
| def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]: | |||
| """ | |||
| Get all JSON Schema definitions for a specific version | |||
| Args: | |||
| version: Schema version, defaults to v1 | |||
| Returns: | |||
| Array containing schema definitions, each element contains name and schema fields | |||
| """ | |||
| @@ -25,31 +25,28 @@ class SchemaManager: | |||
| def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]: | |||
| """ | |||
| Get a specific schema by name | |||
| Args: | |||
| schema_name: Schema name | |||
| version: Schema version, defaults to v1 | |||
| Returns: | |||
| Dictionary containing name and schema, returns None if not found | |||
| """ | |||
| uri = f"https://dify.ai/schemas/{version}/{schema_name}.json" | |||
| schema = self.registry.get_schema(uri) | |||
| if schema: | |||
| return { | |||
| "name": schema_name, | |||
| "schema": schema | |||
| } | |||
| return {"name": schema_name, "schema": schema} | |||
| return None | |||
| def list_available_schemas(self, version: str = "v1") -> list[str]: | |||
| """ | |||
| List all available schema names for a specific version | |||
| Args: | |||
| version: Schema version, defaults to v1 | |||
| Returns: | |||
| List of schema names | |||
| """ | |||
| @@ -58,8 +55,8 @@ class SchemaManager: | |||
| def list_available_versions(self) -> list[str]: | |||
| """ | |||
| List all available schema versions | |||
| Returns: | |||
| List of versions | |||
| """ | |||
| return self.registry.list_versions() | |||
| return self.registry.list_versions() | |||
| @@ -68,10 +68,10 @@ class VariablePool(BaseModel): | |||
| # Add rag pipeline variables to the variable pool | |||
| if self.rag_pipeline_variables: | |||
| rag_pipeline_variables_map = defaultdict(dict) | |||
| for var in self.rag_pipeline_variables: | |||
| node_id = var.variable.belong_to_node_id | |||
| key = var.variable.variable | |||
| value = var.value | |||
| for rag_var in self.rag_pipeline_variables: | |||
| node_id = rag_var.variable.belong_to_node_id | |||
| key = rag_var.variable.variable | |||
| value = rag_var.value | |||
| rag_pipeline_variables_map[node_id][key] = value | |||
| for key, value in rag_pipeline_variables_map.items(): | |||
| self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) | |||
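The hunk above only renames the loop variable; the grouping logic is unchanged. A tiny standalone illustration of that grouping, with plain tuples standing in for `RAGPipelineVariableInput` objects:

    from collections import defaultdict

    rag_pipeline_variables = [
        ("node_a", "query", "cats"),
        ("node_a", "top_k", 3),
        ("node_b", "url", "https://example.com"),
    ]

    grouped: dict[str, dict[str, object]] = defaultdict(dict)
    for node_id, name, value in rag_pipeline_variables:
        grouped[node_id][name] = value  # {node_id: {variable_name: value}}

    print(dict(grouped))
    # {'node_a': {'query': 'cats', 'top_k': 3}, 'node_b': {'url': 'https://example.com'}}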
| @@ -37,12 +37,14 @@ class NodeType(StrEnum): | |||
| ANSWER = "answer" | |||
| LLM = "llm" | |||
| KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" | |||
| KNOWLEDGE_INDEX = "knowledge-index" | |||
| IF_ELSE = "if-else" | |||
| CODE = "code" | |||
| TEMPLATE_TRANSFORM = "template-transform" | |||
| QUESTION_CLASSIFIER = "question-classifier" | |||
| HTTP_REQUEST = "http-request" | |||
| TOOL = "tool" | |||
| DATASOURCE = "datasource" | |||
| VARIABLE_AGGREGATOR = "variable-aggregator" | |||
| LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. | |||
| LOOP = "loop" | |||
| @@ -83,6 +85,7 @@ class WorkflowType(StrEnum): | |||
| WORKFLOW = "workflow" | |||
| CHAT = "chat" | |||
| RAG_PIPELINE = "rag-pipeline" | |||
| class WorkflowExecutionStatus(StrEnum): | |||
| @@ -116,6 +119,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): | |||
| LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs | |||
| ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field | |||
| LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output | |||
| DATASOURCE_INFO = "datasource_info" | |||
| class WorkflowNodeExecutionStatus(StrEnum): | |||
| @@ -109,7 +109,7 @@ class Graph: | |||
| start_node_id = None | |||
| for nid in root_candidates: | |||
| node_data = node_configs_map[nid].get("data", {}) | |||
| if node_data.get("type") == NodeType.START.value: | |||
| if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]: | |||
| start_node_id = nid | |||
| break | |||
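The widened check above lets a rag-pipeline graph start from a datasource node as well as a start node. A minimal sketch of the selection loop, with plain strings in place of the `NodeType` StrEnum values:

    node_configs_map = {
        "n1": {"data": {"type": "datasource"}},
        "n2": {"data": {"type": "llm"}},
    }
    root_candidates = ["n1", "n2"]

    start_node_id = None
    for nid in root_candidates:
        node_data = node_configs_map[nid].get("data", {})
        # "start" and "datasource" stand in for NodeType.START and NodeType.DATASOURCE
        if node_data.get("type") in ("start", "datasource"):
            start_node_id = nid
            break

    print(start_node_id)  # n1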
| @@ -19,16 +19,14 @@ from core.file.enums import FileTransferMethod, FileType | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.variables.segments import ArrayAnySegment | |||
| from core.variables.variables import ArrayAnyVariable | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey | |||
| from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent | |||
| from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||
| from core.workflow.nodes.enums import ErrorStrategy, NodeType | |||
| from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser | |||
| from core.workflow.nodes.tool.exc import ToolFileError | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.model import UploadFile | |||
| @@ -39,7 +37,7 @@ from .entities import DatasourceNodeData | |||
| from .exc import DatasourceNodeError, DatasourceParameterError | |||
| class DatasourceNode(BaseNode): | |||
| class DatasourceNode(Node): | |||
| """ | |||
| Datasource Node | |||
| """ | |||
| @@ -97,8 +95,8 @@ class DatasourceNode(BaseNode): | |||
| datasource_type=DatasourceProviderType.value_of(datasource_type), | |||
| ) | |||
| except DatasourceNodeError as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs={}, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -172,8 +170,8 @@ class DatasourceNode(BaseNode): | |||
| datasource_type=datasource_type, | |||
| ) | |||
| case DatasourceProviderType.WEBSITE_CRAWL: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -204,10 +202,10 @@ class DatasourceNode(BaseNode): | |||
| size=upload_file.size, | |||
| storage_key=upload_file.key, | |||
| ) | |||
| variable_pool.add([self.node_id, "file"], file_info) | |||
| variable_pool.add([self._node_id, "file"], file_info) | |||
| # variable_pool.add([self.node_id, "file"], file_info.to_dict()) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -220,8 +218,8 @@ class DatasourceNode(BaseNode): | |||
| case _: | |||
| raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") | |||
| except PluginDaemonClientSideError as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -230,8 +228,8 @@ class DatasourceNode(BaseNode): | |||
| ) | |||
| ) | |||
| except DatasourceNodeError as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -425,8 +423,10 @@ class DatasourceNode(BaseNode): | |||
| elif message.type == DatasourceMessage.MessageType.TEXT: | |||
| assert isinstance(message.message, DatasourceMessage.TextMessage) | |||
| text += message.message.text | |||
| yield RunStreamChunkEvent( | |||
| chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, "text"], | |||
| chunk=message.message.text, | |||
| is_final=False, | |||
| ) | |||
| elif message.type == DatasourceMessage.MessageType.JSON: | |||
| assert isinstance(message.message, DatasourceMessage.JsonMessage) | |||
| @@ -442,7 +442,11 @@ class DatasourceNode(BaseNode): | |||
| assert isinstance(message.message, DatasourceMessage.TextMessage) | |||
| stream_text = f"Link: {message.message.text}\n" | |||
| text += stream_text | |||
| yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, "text"], | |||
| chunk=stream_text, | |||
| is_final=False, | |||
| ) | |||
| elif message.type == DatasourceMessage.MessageType.VARIABLE: | |||
| assert isinstance(message.message, DatasourceMessage.VariableMessage) | |||
| variable_name = message.message.variable_name | |||
| @@ -454,17 +458,24 @@ class DatasourceNode(BaseNode): | |||
| variables[variable_name] = "" | |||
| variables[variable_name] += variable_value | |||
| yield RunStreamChunkEvent( | |||
| chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, variable_name], | |||
| chunk=variable_value, | |||
| is_final=False, | |||
| ) | |||
| else: | |||
| variables[variable_name] = variable_value | |||
| elif message.type == DatasourceMessage.MessageType.FILE: | |||
| assert message.meta is not None | |||
| files.append(message.meta["file"]) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| # mark the end of the stream | |||
| yield StreamChunkEvent( | |||
| selector=[self._node_id, "text"], | |||
| chunk="", | |||
| is_final=True, | |||
| ) | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={"json": json, "files": files, **variables, "text": text}, | |||
| metadata={ | |||
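The datasource-node hunks above switch from `RunStreamChunkEvent`/`RunCompletedEvent` to `StreamChunkEvent`/`StreamCompletedEvent`, and now emit an explicit empty `is_final=True` chunk before completing. A self-contained mimic of that streaming shape, using small dataclasses rather than the real event types in `core.workflow.node_events`:

    from dataclasses import dataclass, field

    @dataclass
    class StreamChunk:                 # stands in for StreamChunkEvent
        selector: list[str]
        chunk: str
        is_final: bool = False

    @dataclass
    class StreamCompleted:             # stands in for StreamCompletedEvent
        outputs: dict = field(default_factory=dict)

    def stream_text(node_id: str, pieces: list[str]):
        text = ""
        for piece in pieces:
            text += piece
            yield StreamChunk(selector=[node_id, "text"], chunk=piece)
        # close the text stream explicitly before reporting completion
        yield StreamChunk(selector=[node_id, "text"], chunk="", is_final=True)
        yield StreamCompleted(outputs={"text": text})

    for event in stream_text("datasource_1", ["Link: https://example.com\n", "done"]):
        print(event)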
| @@ -526,9 +537,9 @@ class DatasourceNode(BaseNode): | |||
| tenant_id=self.tenant_id, | |||
| ) | |||
| if file: | |||
| variable_pool.add([self.node_id, "file"], file) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| variable_pool.add([self._node_id, "file"], file) | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| @@ -9,16 +9,15 @@ from sqlalchemy import func | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey | |||
| from core.workflow.node_events import NodeRunResult | |||
| from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||
| from core.workflow.nodes.enums import ErrorStrategy, NodeType | |||
| from core.workflow.nodes.base.node import Node | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from ..base import BaseNode | |||
| from .entities import KnowledgeIndexNodeData | |||
| from .exc import ( | |||
| KnowledgeIndexNodeError, | |||
| @@ -35,7 +34,7 @@ default_retrieval_model = { | |||
| } | |||
| class KnowledgeIndexNode(BaseNode): | |||
| class KnowledgeIndexNode(Node): | |||
| _node_data: KnowledgeIndexNodeData | |||
| _node_type = NodeType.KNOWLEDGE_INDEX | |||
| @@ -93,15 +92,12 @@ class KnowledgeIndexNode(BaseNode): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=variables, | |||
| process_data=None, | |||
| outputs=outputs, | |||
| ) | |||
| results = self._invoke_knowledge_index( | |||
| dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results | |||
| ) | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results) | |||
| except KnowledgeIndexNodeError as e: | |||
| logger.warning("Error when running knowledge index node") | |||
| @@ -172,7 +172,7 @@ class Dataset(Base): | |||
| ) | |||
| @property | |||
| def doc_form(self): | |||
| def doc_form(self) -> Optional[str]: | |||
| if self.chunk_structure: | |||
| return self.chunk_structure | |||
| document = db.session.query(Document).filter(Document.dataset_id == self.id).first() | |||
| @@ -424,7 +424,7 @@ class Document(Base): | |||
| return status | |||
| @property | |||
| def data_source_info_dict(self): | |||
| def data_source_info_dict(self) -> dict[str, Any]: | |||
| if self.data_source_info: | |||
| try: | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| @@ -432,7 +432,7 @@ class Document(Base): | |||
| data_source_info_dict = {} | |||
| return data_source_info_dict | |||
| return None | |||
| return {} | |||
| @property | |||
| def data_source_detail_dict(self): | |||
| @@ -52,3 +52,8 @@ class ToolProviderID(GenericProviderID): | |||
| if self.organization == "langgenius": | |||
| if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: | |||
| self.plugin_name = f"{self.provider_name}_tool" | |||
| class DatasourceProviderID(GenericProviderID): | |||
| def __init__(self, value: str, is_hardcoded: bool = False) -> None: | |||
| super().__init__(value, is_hardcoded) | |||
| @@ -718,9 +718,9 @@ class DatasetService: | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=knowledge_configuration.embedding_model_provider, | |||
| provider=knowledge_configuration.embedding_model_provider or "", | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=knowledge_configuration.embedding_model, | |||
| model=knowledge_configuration.embedding_model or "", | |||
| ) | |||
| dataset.embedding_model = embedding_model.model | |||
| dataset.embedding_model_provider = embedding_model.provider | |||
| @@ -1159,7 +1159,7 @@ class DocumentService: | |||
| return | |||
| documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() | |||
| file_ids = [ | |||
| document.data_source_info_dict["upload_file_id"] | |||
| document.data_source_info_dict.get("upload_file_id", "") | |||
| for document in documents | |||
| if document.data_source_type == "upload_file" | |||
| ] | |||
| @@ -1281,7 +1281,7 @@ class DocumentService: | |||
| account: Account | Any, | |||
| dataset_process_rule: Optional[DatasetProcessRule] = None, | |||
| created_from: str = "web", | |||
| ): | |||
| ) -> tuple[list[Document], str]: | |||
| # check doc_form | |||
| DatasetService.check_doc_form(dataset, knowledge_config.doc_form) | |||
| # check document limit | |||
| @@ -1386,7 +1386,7 @@ class DocumentService: | |||
| "Invalid process rule mode: %s, can not find dataset process rule", | |||
| process_rule.mode, | |||
| ) | |||
| return | |||
| return [], "" | |||
| db.session.add(dataset_process_rule) | |||
| db.session.flush() | |||
| lock_name = f"add_document_lock_dataset_id_{dataset.id}" | |||
| @@ -2595,7 +2595,9 @@ class SegmentService: | |||
| return segment_data_list | |||
| @classmethod | |||
| def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): | |||
| def update_segment( | |||
| cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset | |||
| ) -> DocumentSegment: | |||
| indexing_cache_key = f"segment_{segment.id}_indexing" | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| @@ -2764,6 +2766,8 @@ class SegmentService: | |||
| segment.error = str(e) | |||
| db.session.commit() | |||
| new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() | |||
| if not new_segment: | |||
| raise ValueError("new_segment is not found") | |||
| return new_segment | |||
| @classmethod | |||
| @@ -2804,7 +2808,11 @@ class SegmentService: | |||
| index_node_ids = [seg.index_node_id for seg in segments] | |||
| total_words = sum(seg.word_count for seg in segments) | |||
| document.word_count -= total_words | |||
| if document.word_count is None: | |||
| document.word_count = 0 | |||
| else: | |||
| document.word_count = max(0, document.word_count - total_words) | |||
| db.session.add(document) | |||
| delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) | |||
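Editor's note: the segment-deletion path now tolerates a `NULL` `word_count` and clamps the decrement at zero. The guard is equivalent to this small helper (illustrative only):

```python
from typing import Optional


def decrement_word_count(current: Optional[int], removed: int) -> int:
    """Treat a missing count as zero and never go below zero."""
    return max(0, (current or 0) - removed)


assert decrement_word_count(None, 10) == 0
assert decrement_word_count(5, 10) == 0
assert decrement_word_count(25, 10) == 15
```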
| @@ -11,7 +11,6 @@ from core.helper import encrypter | |||
| from core.helper.name_generator import generate_incremental_name | |||
| from core.helper.provider_cache import NoOpProviderCredentialCache | |||
| from core.model_runtime.entities.provider_entities import FormType | |||
| from core.plugin.entities.plugin import DatasourceProviderID | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from core.tools.entities.tool_entities import CredentialType | |||
| @@ -19,6 +18,7 @@ from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncry | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider | |||
| from models.provider_ids import DatasourceProviderID | |||
| from services.plugin.plugin_service import PluginService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -809,9 +809,7 @@ class DatasourceProviderService: | |||
| credentials = self.list_datasource_credentials( | |||
| tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id | |||
| ) | |||
| redirect_uri = ( | |||
| f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" | |||
| ) | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" | |||
| datasource_credentials.append( | |||
| { | |||
| "provider": datasource.provider, | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, field_validator | |||
| class IconInfo(BaseModel): | |||
| @@ -110,7 +110,21 @@ class KnowledgeConfiguration(BaseModel): | |||
| chunk_structure: str | |||
| indexing_technique: Literal["high_quality", "economy"] | |||
| embedding_model_provider: Optional[str] = "" | |||
| embedding_model: Optional[str] = "" | |||
| embedding_model_provider: str = "" | |||
| embedding_model: str = "" | |||
| keyword_number: Optional[int] = 10 | |||
| retrieval_model: RetrievalSetting | |||
| @field_validator("embedding_model_provider", mode="before") | |||
| @classmethod | |||
| def validate_embedding_model_provider(cls, v): | |||
| if v is None: | |||
| return "" | |||
| return v | |||
| @field_validator("embedding_model", mode="before") | |||
| @classmethod | |||
| def validate_embedding_model(cls, v): | |||
| if v is None: | |||
| return "" | |||
| return v | |||
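Editor's note: the `mode="before"` validators coerce an incoming `None` to `""` before type validation runs, which is what allows the two fields to be declared as plain `str` defaults while still accepting payloads that send `null`. A self-contained sketch of the same Pydantic v2 pattern, using a simplified model rather than the full `KnowledgeConfiguration`:

```python
from pydantic import BaseModel, field_validator


class EmbeddingConfigSketch(BaseModel):
    """Simplified stand-in covering only the embedding-related fields."""

    embedding_model_provider: str = ""
    embedding_model: str = ""

    @field_validator("embedding_model_provider", "embedding_model", mode="before")
    @classmethod
    def _none_to_empty(cls, v):
        # Runs before type validation, so None never reaches the str check.
        return "" if v is None else v


cfg = EmbeddingConfigSketch(embedding_model_provider=None, embedding_model=None)
assert cfg.embedding_model_provider == ""
assert cfg.embedding_model == ""
```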
| @@ -28,26 +28,23 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen | |||
| from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin | |||
| from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin | |||
| from core.rag.entities.event import ( | |||
| BaseDatasourceEvent, | |||
| DatasourceCompletedEvent, | |||
| DatasourceErrorEvent, | |||
| DatasourceProcessingEvent, | |||
| ) | |||
| from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.variables.variables import Variable | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import ( | |||
| WorkflowNodeExecution, | |||
| WorkflowNodeExecutionStatus, | |||
| ) | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey | |||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||
| from core.workflow.graph_engine.entities.event import InNodeEvent | |||
| from core.workflow.nodes.base.node import BaseNode | |||
| from core.workflow.nodes.enums import ErrorStrategy, NodeType | |||
| from core.workflow.nodes.event.event import RunCompletedEvent | |||
| from core.workflow.nodes.event.types import NodeEvent | |||
| from core.workflow.graph_events.base import GraphNodeEventBase | |||
| from core.workflow.node_events.base import NodeRunResult | |||
| from core.workflow.node_events.node import StreamCompletedEvent | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.repositories.workflow_node_execution_repository import OrderConfig | |||
| from core.workflow.system_variable import SystemVariable | |||
| @@ -105,12 +102,13 @@ class RagPipelineService: | |||
| if type == "built-in": | |||
| mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE | |||
| retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() | |||
| result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) | |||
| built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) | |||
| return built_in_result | |||
| else: | |||
| mode = "customized" | |||
| retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() | |||
| result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) | |||
| return result | |||
| customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) | |||
| return customized_result | |||
| @classmethod | |||
| def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity): | |||
| @@ -471,7 +469,7 @@ class RagPipelineService: | |||
| datasource_type: str, | |||
| is_published: bool, | |||
| credential_id: Optional[str] = None, | |||
| ) -> Generator[BaseDatasourceEvent, None, None]: | |||
| ) -> Generator[Mapping[str, Any], None, None]: | |||
| """ | |||
| Run published workflow datasource | |||
| """ | |||
| @@ -563,9 +561,9 @@ class RagPipelineService: | |||
| user_id=account.id, | |||
| request=OnlineDriveBrowseFilesRequest( | |||
| bucket=user_inputs.get("bucket"), | |||
| prefix=user_inputs.get("prefix"), | |||
| prefix=user_inputs.get("prefix", ""), | |||
| max_keys=user_inputs.get("max_keys", 20), | |||
| start_after=user_inputs.get("start_after"), | |||
| next_page_parameters=user_inputs.get("next_page_parameters"), | |||
| ), | |||
| provider_type=datasource_runtime.datasource_provider_type(), | |||
| ) | |||
| @@ -600,7 +598,7 @@ class RagPipelineService: | |||
| end_time = time.time() | |||
| if message.result.status == "completed": | |||
| crawl_event = DatasourceCompletedEvent( | |||
| data=message.result.web_info_list, | |||
| data=message.result.web_info_list or [], | |||
| total=message.result.total, | |||
| completed=message.result.completed, | |||
| time_consuming=round(end_time - start_time, 2), | |||
| @@ -681,9 +679,9 @@ class RagPipelineService: | |||
| datasource_runtime.get_online_document_page_content( | |||
| user_id=account.id, | |||
| datasource_parameters=GetOnlineDocumentPageContentRequest( | |||
| workspace_id=user_inputs.get("workspace_id"), | |||
| page_id=user_inputs.get("page_id"), | |||
| type=user_inputs.get("type"), | |||
| workspace_id=user_inputs.get("workspace_id", ""), | |||
| page_id=user_inputs.get("page_id", ""), | |||
| type=user_inputs.get("type", ""), | |||
| ), | |||
| provider_type=datasource_type, | |||
| ) | |||
| @@ -740,7 +738,7 @@ class RagPipelineService: | |||
| def _handle_node_run_result( | |||
| self, | |||
| getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], | |||
| getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]], | |||
| start_at: float, | |||
| tenant_id: str, | |||
| node_id: str, | |||
| @@ -758,17 +756,16 @@ class RagPipelineService: | |||
| node_run_result: NodeRunResult | None = None | |||
| for event in generator: | |||
| if isinstance(event, RunCompletedEvent): | |||
| node_run_result = event.run_result | |||
| if isinstance(event, StreamCompletedEvent): | |||
| node_run_result = event.node_run_result | |||
| # sign output files | |||
| node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) | |||
| node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {} | |||
| break | |||
| if not node_run_result: | |||
| raise ValueError("Node run failed with no run result") | |||
| # single step debug mode error handling return | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error: | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy: | |||
| node_error_args: dict[str, Any] = { | |||
| "status": WorkflowNodeExecutionStatus.EXCEPTION, | |||
| "error": node_run_result.error, | |||
| @@ -808,7 +805,7 @@ class RagPipelineService: | |||
| workflow_id=node_instance.workflow_id, | |||
| index=1, | |||
| node_id=node_id, | |||
| node_type=node_instance.type_, | |||
| node_type=node_instance.node_type, | |||
| title=node_instance.title, | |||
| elapsed_time=time.perf_counter() - start_at, | |||
| finished_at=datetime.now(UTC).replace(tzinfo=None), | |||
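Editor's note: `_handle_node_run_result` now drains `GraphNodeEventBase` events and takes the result from `StreamCompletedEvent.node_run_result`, and the exception path keys off `node_instance.error_strategy` rather than `continue_on_error`. A rough sketch of the event-draining loop using stand-in dataclasses (the real event and result classes live in `core.workflow` and carry more fields than shown here):

```python
from dataclasses import dataclass, field
from typing import Any, Iterator, Optional


@dataclass
class RunResultSketch:
    status: str
    outputs: dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None


@dataclass
class StreamCompletedSketch:
    node_run_result: RunResultSketch


def drain_until_completed(events: Iterator[object]) -> RunResultSketch:
    """Consume node events until the completion event, as the service loop does."""
    for event in events:
        if isinstance(event, StreamCompletedSketch):
            result = event.node_run_result
            # The real service signs file URLs in the outputs at this point.
            result.outputs = result.outputs or {}
            return result
    raise ValueError("Node run failed with no run result")


events = iter([object(), StreamCompletedSketch(RunResultSketch(status="succeeded"))])
assert drain_until_completed(events).status == "succeeded"
```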
| @@ -1148,7 +1145,7 @@ class RagPipelineService: | |||
| .first() | |||
| ) | |||
| return node_exec | |||
| def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser): | |||
| # fetch draft workflow by app_model | |||
| draft_workflow = self.get_draft_workflow(pipeline=pipeline) | |||
| @@ -1208,6 +1205,3 @@ class RagPipelineService: | |||
| ) | |||
| session.commit() | |||
| return workflow_node_execution_db_model | |||
| @@ -23,8 +23,8 @@ from core.helper import ssrf_proxy | |||
| from core.helper.name_generator import generate_incremental_name | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.entities.plugin import PluginDependency | |||
| from core.workflow.enums import NodeType | |||
| from core.workflow.nodes.datasource.entities import DatasourceNodeData | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | |||
| from core.workflow.nodes.llm.entities import LLMNodeData | |||
| from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData | |||
| @@ -281,7 +281,7 @@ class RagPipelineDslService: | |||
| icon = icon_info.icon | |||
| icon_background = icon_info.icon_background | |||
| icon_url = icon_info.icon_url | |||
| else: | |||
| else: | |||
| icon_type = data.get("rag_pipeline", {}).get("icon_type") | |||
| icon = data.get("rag_pipeline", {}).get("icon") | |||
| icon_background = data.get("rag_pipeline", {}).get("icon_background") | |||
| @@ -1,6 +1,7 @@ | |||
| import json | |||
| from datetime import UTC, datetime | |||
| from pathlib import Path | |||
| from typing import Optional | |||
| from uuid import uuid4 | |||
| import yaml | |||
| @@ -87,7 +88,7 @@ class RagPipelineTransformService: | |||
| "status": "success", | |||
| } | |||
| def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str): | |||
| def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]): | |||
| if doc_form == "text_model": | |||
| match datasource_type: | |||
| case "upload_file": | |||
| @@ -148,7 +149,7 @@ class RagPipelineTransformService: | |||
| return node | |||
| def _deal_knowledge_index( | |||
| self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict | |||
| self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict | |||
| ): | |||
| knowledge_configuration_dict = node.get("data", {}) | |||
| knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict) | |||
| @@ -1,5 +1,6 @@ | |||
| import logging | |||
| import time | |||
| from typing import Optional | |||
| import click | |||
| from celery import shared_task | |||
| @@ -15,7 +16,7 @@ logger = logging.getLogger(__name__) | |||
| @shared_task(queue="dataset") | |||
| def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): | |||
| def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]): | |||
| """ | |||
| Clean document when document deleted. | |||
| :param document_ids: document ids | |||
| @@ -29,6 +30,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form | |||
| start_at = time.perf_counter() | |||
| try: | |||
| if not doc_form: | |||
| raise ValueError("doc_form is required") | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| @@ -21,14 +21,16 @@ from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| @shared_task(queue="dataset") | |||
| def rag_pipeline_run_task(pipeline_id: str, | |||
| application_generate_entity: dict, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| workflow_id: str, | |||
| streaming: bool, | |||
| workflow_execution_id: str | None = None, | |||
| workflow_thread_pool_id: str | None = None): | |||
| def rag_pipeline_run_task( | |||
| pipeline_id: str, | |||
| application_generate_entity: dict, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| workflow_id: str, | |||
| streaming: bool, | |||
| workflow_execution_id: str | None = None, | |||
| workflow_thread_pool_id: str | None = None, | |||
| ): | |||
| """ | |||
| Async Run rag pipeline | |||
| :param pipeline_id: Pipeline ID | |||
| @@ -94,18 +96,19 @@ def rag_pipeline_run_task(pipeline_id: str, | |||
| with current_app.app_context(): | |||
| # Set the user directly in g for preserve_flask_contexts | |||
| g._login_user = account | |||
| # Copy context for thread (after setting user) | |||
| context = contextvars.copy_context() | |||
| # Get Flask app object in the main thread where app context exists | |||
| flask_app = current_app._get_current_object() # type: ignore | |||
| # Create a wrapper function that passes user context | |||
| def _run_with_user_context(): | |||
| # Don't create a new app context here - let _generate handle it | |||
| # Just ensure the user is available in contextvars | |||
| from core.app.apps.pipeline.pipeline_generator import PipelineGenerator | |||
| pipeline_generator = PipelineGenerator() | |||
| pipeline_generator._generate( | |||
| flask_app=flask_app, | |||
| @@ -120,7 +123,7 @@ def rag_pipeline_run_task(pipeline_id: str, | |||
| streaming=streaming, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| # Create and start worker thread | |||
| worker_thread = threading.Thread(target=_run_with_user_context) | |||
| worker_thread.start() | |||
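Editor's note: the task snapshots the caller's `contextvars` state and grabs the Flask app object on the main thread before handing work to a worker thread. A stripped-down sketch of propagating contextvars into a thread (no Flask here; the app-context handling is only hinted at in the docstring, and the real task wires the copied context and app object through the pipeline generator):

```python
import contextvars
import threading

request_user: contextvars.ContextVar[str] = contextvars.ContextVar("request_user")


def run_in_worker(target, *args) -> threading.Thread:
    """Run `target` in a thread while preserving the caller's contextvars.

    The real task additionally captures the Flask app object on the main
    thread and lets the pipeline generator push its own app context.
    """
    ctx = contextvars.copy_context()  # snapshot taken on the calling thread
    thread = threading.Thread(target=lambda: ctx.run(target, *args))
    thread.start()
    return thread


def worker() -> None:
    # Sees the value that was set before the context was copied.
    print("running as", request_user.get())


request_user.set("account-123")
run_in_worker(worker).join()
```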
| @@ -1 +1 @@ | |||
| # Core schemas unit tests | |||
| # Core schemas unit tests | |||
| @@ -33,18 +33,16 @@ class TestSchemaResolver: | |||
| def test_simple_ref_resolution(self): | |||
| """Test resolving a simple $ref to a complete schema""" | |||
| schema_with_ref = { | |||
| "$ref": "https://dify.ai/schemas/v1/qa_structure.json" | |||
| } | |||
| schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} | |||
| resolved = resolve_dify_schema_refs(schema_with_ref) | |||
| # Should be resolved to the actual qa_structure schema | |||
| assert resolved["type"] == "object" | |||
| assert resolved["title"] == "Q&A Structure Schema" | |||
| assert "qa_chunks" in resolved["properties"] | |||
| assert resolved["properties"]["qa_chunks"]["type"] == "array" | |||
| # Metadata fields should be removed | |||
| assert "$id" not in resolved | |||
| assert "$schema" not in resolved | |||
| @@ -55,29 +53,24 @@ class TestSchemaResolver: | |||
| nested_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "file_data": { | |||
| "$ref": "https://dify.ai/schemas/v1/file.json" | |||
| }, | |||
| "metadata": { | |||
| "type": "string", | |||
| "description": "Additional metadata" | |||
| } | |||
| } | |||
| "file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| "metadata": {"type": "string", "description": "Additional metadata"}, | |||
| }, | |||
| } | |||
| resolved = resolve_dify_schema_refs(nested_schema) | |||
| # Original structure should be preserved | |||
| assert resolved["type"] == "object" | |||
| assert "metadata" in resolved["properties"] | |||
| assert resolved["properties"]["metadata"]["type"] == "string" | |||
| # $ref should be resolved | |||
| file_schema = resolved["properties"]["file_data"] | |||
| assert file_schema["type"] == "object" | |||
| assert file_schema["title"] == "File Schema" | |||
| assert "name" in file_schema["properties"] | |||
| # Metadata fields should be removed from resolved schema | |||
| assert "$id" not in file_schema | |||
| assert "$schema" not in file_schema | |||
| @@ -87,18 +80,16 @@ class TestSchemaResolver: | |||
| """Test resolving $refs in array items""" | |||
| array_schema = { | |||
| "type": "array", | |||
| "items": { | |||
| "$ref": "https://dify.ai/schemas/v1/general_structure.json" | |||
| }, | |||
| "description": "Array of general structures" | |||
| "items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}, | |||
| "description": "Array of general structures", | |||
| } | |||
| resolved = resolve_dify_schema_refs(array_schema) | |||
| # Array structure should be preserved | |||
| assert resolved["type"] == "array" | |||
| assert resolved["description"] == "Array of general structures" | |||
| # Items $ref should be resolved | |||
| items_schema = resolved["items"] | |||
| assert items_schema["type"] == "array" | |||
| @@ -109,20 +100,16 @@ class TestSchemaResolver: | |||
| external_ref_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "external_data": { | |||
| "$ref": "https://example.com/external-schema.json" | |||
| }, | |||
| "dify_data": { | |||
| "$ref": "https://dify.ai/schemas/v1/file.json" | |||
| } | |||
| } | |||
| "external_data": {"$ref": "https://example.com/external-schema.json"}, | |||
| "dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| }, | |||
| } | |||
| resolved = resolve_dify_schema_refs(external_ref_schema) | |||
| # External $ref should remain unchanged | |||
| assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json" | |||
| # Dify $ref should be resolved | |||
| assert resolved["properties"]["dify_data"]["type"] == "object" | |||
| assert resolved["properties"]["dify_data"]["title"] == "File Schema" | |||
| @@ -132,22 +119,14 @@ class TestSchemaResolver: | |||
| simple_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "name": { | |||
| "type": "string", | |||
| "description": "Name field" | |||
| }, | |||
| "items": { | |||
| "type": "array", | |||
| "items": { | |||
| "type": "number" | |||
| } | |||
| } | |||
| "name": {"type": "string", "description": "Name field"}, | |||
| "items": {"type": "array", "items": {"type": "number"}}, | |||
| }, | |||
| "required": ["name"] | |||
| "required": ["name"], | |||
| } | |||
| resolved = resolve_dify_schema_refs(simple_schema) | |||
| # Should be identical to input | |||
| assert resolved == simple_schema | |||
| assert resolved["type"] == "object" | |||
| @@ -159,21 +138,16 @@ class TestSchemaResolver: | |||
| """Test that excessive recursion depth is prevented""" | |||
| # Create a moderately nested structure | |||
| deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} | |||
| # Wrap it in fewer layers to make the test more reasonable | |||
| for _ in range(2): | |||
| deep_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "nested": deep_schema | |||
| } | |||
| } | |||
| deep_schema = {"type": "object", "properties": {"nested": deep_schema}} | |||
| # Should handle normal cases fine with reasonable depth | |||
| resolved = resolve_dify_schema_refs(deep_schema, max_depth=25) | |||
| assert resolved is not None | |||
| assert resolved["type"] == "object" | |||
| # Should raise error with very low max_depth | |||
| with pytest.raises(MaxDepthExceededError) as exc_info: | |||
| resolve_dify_schema_refs(deep_schema, max_depth=5) | |||
| @@ -185,12 +159,12 @@ class TestSchemaResolver: | |||
| mock_registry = MagicMock() | |||
| mock_registry.get_schema.side_effect = lambda uri: { | |||
| "$ref": "https://dify.ai/schemas/v1/circular.json", | |||
| "type": "object" | |||
| "type": "object", | |||
| } | |||
| schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"} | |||
| resolved = resolve_dify_schema_refs(schema, registry=mock_registry) | |||
| # Should mark circular reference | |||
| assert "$circular_ref" in resolved | |||
| @@ -199,10 +173,10 @@ class TestSchemaResolver: | |||
| # Mock registry that returns None for unknown schemas | |||
| mock_registry = MagicMock() | |||
| mock_registry.get_schema.return_value = None | |||
| schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"} | |||
| resolved = resolve_dify_schema_refs(schema, registry=mock_registry) | |||
| # Should keep the original $ref when schema not found | |||
| assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json" | |||
| @@ -217,25 +191,25 @@ class TestSchemaResolver: | |||
| def test_cache_functionality(self): | |||
| """Test that caching works correctly""" | |||
| schema = {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| # First resolution should fetch from registry | |||
| resolved1 = resolve_dify_schema_refs(schema) | |||
| # Mock the registry to return different data | |||
| with patch.object(self.registry, "get_schema") as mock_get: | |||
| mock_get.return_value = {"type": "different"} | |||
| # Second resolution should use cache | |||
| resolved2 = resolve_dify_schema_refs(schema) | |||
| # Should be the same as first resolution (from cache) | |||
| assert resolved1 == resolved2 | |||
| # Mock should not have been called | |||
| mock_get.assert_not_called() | |||
| # Clear cache and try again | |||
| SchemaResolver.clear_cache() | |||
| # Now it should fetch again | |||
| resolved3 = resolve_dify_schema_refs(schema) | |||
| assert resolved3 == resolved1 | |||
| @@ -244,14 +218,11 @@ class TestSchemaResolver: | |||
| """Test that the resolver is thread-safe""" | |||
| schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| for i in range(10) | |||
| } | |||
| "properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)}, | |||
| } | |||
| results = [] | |||
| def resolve_in_thread(): | |||
| try: | |||
| result = resolve_dify_schema_refs(schema) | |||
| @@ -260,12 +231,12 @@ class TestSchemaResolver: | |||
| except Exception as e: | |||
| results.append(e) | |||
| return False | |||
| # Run multiple threads concurrently | |||
| with ThreadPoolExecutor(max_workers=10) as executor: | |||
| futures = [executor.submit(resolve_in_thread) for _ in range(20)] | |||
| success = all(f.result() for f in futures) | |||
| assert success | |||
| # All results should be the same | |||
| first_result = results[0] | |||
| @@ -276,10 +247,7 @@ class TestSchemaResolver: | |||
| complex_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "files": { | |||
| "type": "array", | |||
| "items": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| }, | |||
| "files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}, | |||
| "nested": { | |||
| "type": "object", | |||
| "properties": { | |||
| @@ -290,21 +258,21 @@ class TestSchemaResolver: | |||
| "type": "object", | |||
| "properties": { | |||
| "general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"} | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| }, | |||
| } | |||
| resolved = resolve_dify_schema_refs(complex_schema, max_depth=20) | |||
| # Check structure is preserved | |||
| assert resolved["type"] == "object" | |||
| assert "files" in resolved["properties"] | |||
| assert "nested" in resolved["properties"] | |||
| # Check refs are resolved | |||
| assert resolved["properties"]["files"]["items"]["type"] == "object" | |||
| assert resolved["properties"]["files"]["items"]["title"] == "File Schema" | |||
| @@ -314,14 +282,14 @@ class TestSchemaResolver: | |||
| class TestUtilityFunctions: | |||
| """Test utility functions""" | |||
| def test_is_dify_schema_ref(self): | |||
| """Test _is_dify_schema_ref function""" | |||
| # Valid Dify refs | |||
| assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json") | |||
| assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json") | |||
| assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json") | |||
| # Invalid refs | |||
| assert not _is_dify_schema_ref("https://example.com/schema.json") | |||
| assert not _is_dify_schema_ref("https://dify.ai/other/path.json") | |||
| @@ -330,61 +298,46 @@ class TestUtilityFunctions: | |||
| assert not _is_dify_schema_ref(None) | |||
| assert not _is_dify_schema_ref(123) | |||
| assert not _is_dify_schema_ref(["list"]) | |||
| def test_has_dify_refs(self): | |||
| """Test _has_dify_refs function""" | |||
| # Schemas with Dify refs | |||
| assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"}) | |||
| assert _has_dify_refs({ | |||
| "type": "object", | |||
| "properties": { | |||
| "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| } | |||
| }) | |||
| assert _has_dify_refs([ | |||
| {"type": "string"}, | |||
| {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| ]) | |||
| assert _has_dify_refs({ | |||
| "type": "array", | |||
| "items": { | |||
| "type": "object", | |||
| "properties": { | |||
| "nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} | |||
| } | |||
| assert _has_dify_refs( | |||
| {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}} | |||
| ) | |||
| assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}]) | |||
| assert _has_dify_refs( | |||
| { | |||
| "type": "array", | |||
| "items": { | |||
| "type": "object", | |||
| "properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}}, | |||
| }, | |||
| } | |||
| }) | |||
| ) | |||
| # Schemas without Dify refs | |||
| assert not _has_dify_refs({"type": "string"}) | |||
| assert not _has_dify_refs({ | |||
| "type": "object", | |||
| "properties": { | |||
| "name": {"type": "string"}, | |||
| "age": {"type": "number"} | |||
| } | |||
| }) | |||
| assert not _has_dify_refs([ | |||
| {"type": "string"}, | |||
| {"type": "number"}, | |||
| {"type": "object", "properties": {"name": {"type": "string"}}} | |||
| ]) | |||
| assert not _has_dify_refs( | |||
| {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}} | |||
| ) | |||
| assert not _has_dify_refs( | |||
| [{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}] | |||
| ) | |||
| # Schemas with non-Dify refs (should return False) | |||
| assert not _has_dify_refs({"$ref": "https://example.com/schema.json"}) | |||
| assert not _has_dify_refs({ | |||
| "type": "object", | |||
| "properties": { | |||
| "external": {"$ref": "https://example.com/external.json"} | |||
| } | |||
| }) | |||
| assert not _has_dify_refs( | |||
| {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}} | |||
| ) | |||
| # Primitive types | |||
| assert not _has_dify_refs("string") | |||
| assert not _has_dify_refs(123) | |||
| assert not _has_dify_refs(True) | |||
| assert not _has_dify_refs(None) | |||
| def test_has_dify_refs_hybrid_vs_recursive(self): | |||
| """Test that hybrid and recursive detection give same results""" | |||
| test_schemas = [ | |||
| @@ -392,29 +345,13 @@ class TestUtilityFunctions: | |||
| {"type": "string"}, | |||
| {"type": "object", "properties": {"name": {"type": "string"}}}, | |||
| [{"type": "string"}, {"type": "number"}], | |||
| # With Dify refs | |||
| # With Dify refs | |||
| {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| { | |||
| "type": "object", | |||
| "properties": { | |||
| "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| } | |||
| }, | |||
| [ | |||
| {"type": "string"}, | |||
| {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} | |||
| ], | |||
| {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}, | |||
| [{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}], | |||
| # With non-Dify refs | |||
| {"$ref": "https://example.com/schema.json"}, | |||
| { | |||
| "type": "object", | |||
| "properties": { | |||
| "external": {"$ref": "https://example.com/external.json"} | |||
| } | |||
| }, | |||
| {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}, | |||
| # Complex nested | |||
| { | |||
| "type": "object", | |||
| @@ -422,41 +359,40 @@ class TestUtilityFunctions: | |||
| "level1": { | |||
| "type": "object", | |||
| "properties": { | |||
| "level2": { | |||
| "type": "array", | |||
| "items": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| } | |||
| } | |||
| "level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}} | |||
| }, | |||
| } | |||
| } | |||
| }, | |||
| }, | |||
| # Edge cases | |||
| {"description": "This mentions $ref but is not a reference"}, | |||
| {"$ref": "not-a-url"}, | |||
| # Primitive types | |||
| "string", 123, True, None, [] | |||
| "string", | |||
| 123, | |||
| True, | |||
| None, | |||
| [], | |||
| ] | |||
| for schema in test_schemas: | |||
| hybrid_result = _has_dify_refs_hybrid(schema) | |||
| recursive_result = _has_dify_refs_recursive(schema) | |||
| assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}" | |||
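Editor's note: the hybrid detector these tests benchmark against the plain recursive walk can be built from a single serialization plus a regex scan, with recursion kept as a fallback for non-JSON-serializable input; that is also why the `$ref`-mentioned-only-in-a-description case must not count. A minimal sketch of the idea, not the project's actual `_has_dify_refs_hybrid`:

```python
import json
import re
from typing import Any

_DIFY_REF_RE = re.compile(r'"\$ref"\s*:\s*"https://dify\.ai/schemas/v\d+/[^"]+\.json"')


def has_dify_refs_recursive(node: Any) -> bool:
    """Plain recursive walk: only a literal "$ref" key pointing at a Dify URI counts."""
    if isinstance(node, dict):
        ref = node.get("$ref")
        if isinstance(ref, str) and ref.startswith("https://dify.ai/schemas/"):
            return True
        return any(has_dify_refs_recursive(v) for v in node.values())
    if isinstance(node, list):
        return any(has_dify_refs_recursive(item) for item in node)
    return False


def has_dify_refs_hybrid(node: Any) -> bool:
    """Fast path: one json.dumps plus a regex; fall back to the walk if not serializable."""
    try:
        return bool(_DIFY_REF_RE.search(json.dumps(node)))
    except (TypeError, ValueError):
        return has_dify_refs_recursive(node)


# A "$ref" mentioned only in prose must not count:
assert not has_dify_refs_hybrid({"description": "This mentions $ref but is not a reference"})
assert has_dify_refs_hybrid({"$ref": "https://dify.ai/schemas/v1/file.json"})
```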
| def test_parse_dify_schema_uri(self): | |||
| """Test parse_dify_schema_uri function""" | |||
| # Valid URIs | |||
| assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file") | |||
| assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name") | |||
| assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file") | |||
| # Invalid URIs | |||
| assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "") | |||
| assert parse_dify_schema_uri("invalid") == ("", "") | |||
| assert parse_dify_schema_uri("") == ("", "") | |||
| def test_remove_metadata_fields(self): | |||
| """Test _remove_metadata_fields function""" | |||
| schema = { | |||
| @@ -465,68 +401,68 @@ class TestUtilityFunctions: | |||
| "version": "should be removed", | |||
| "type": "object", | |||
| "title": "should remain", | |||
| "properties": {} | |||
| "properties": {}, | |||
| } | |||
| cleaned = _remove_metadata_fields(schema) | |||
| assert "$id" not in cleaned | |||
| assert "$schema" not in cleaned | |||
| assert "version" not in cleaned | |||
| assert cleaned["type"] == "object" | |||
| assert cleaned["title"] == "should remain" | |||
| assert "properties" in cleaned | |||
| # Original should be unchanged | |||
| assert "$id" in schema | |||
| class TestSchemaResolverClass: | |||
| """Test SchemaResolver class specifically""" | |||
| def test_resolver_initialization(self): | |||
| """Test resolver initialization""" | |||
| # Default initialization | |||
| resolver = SchemaResolver() | |||
| assert resolver.max_depth == 10 | |||
| assert resolver.registry is not None | |||
| # Custom initialization | |||
| custom_registry = MagicMock() | |||
| resolver = SchemaResolver(registry=custom_registry, max_depth=5) | |||
| assert resolver.max_depth == 5 | |||
| assert resolver.registry is custom_registry | |||
| def test_cache_sharing(self): | |||
| """Test that cache is shared between resolver instances""" | |||
| SchemaResolver.clear_cache() | |||
| schema = {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| # First resolver populates cache | |||
| resolver1 = SchemaResolver() | |||
| result1 = resolver1.resolve(schema) | |||
| # Second resolver should use the same cache | |||
| resolver2 = SchemaResolver() | |||
| with patch.object(resolver2.registry, "get_schema") as mock_get: | |||
| result2 = resolver2.resolve(schema) | |||
| # Should not call registry since it's in cache | |||
| mock_get.assert_not_called() | |||
| assert result1 == result2 | |||
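Editor's note: `test_cache_sharing` depends on the resolver keeping its schema cache at class level, so every instance (and thread) reads and writes the same entries. A minimal sketch of a shared, lock-protected class cache; the names are illustrative, not the real `SchemaResolver` internals:

```python
import threading
from typing import Any


class CachedFetcher:
    """Class-level cache shared by all instances, guarded by one lock."""

    _cache: dict[str, Any] = {}
    _lock = threading.Lock()

    def __init__(self, fetch):
        self._fetch = fetch  # callable: uri -> schema dict

    def get(self, uri: str) -> Any:
        with self._lock:
            if uri in self._cache:
                return self._cache[uri]
        value = self._fetch(uri)
        with self._lock:
            self._cache.setdefault(uri, value)
        return value

    @classmethod
    def clear_cache(cls) -> None:
        with cls._lock:
            cls._cache.clear()


calls = []
a = CachedFetcher(lambda uri: calls.append(uri) or {"uri": uri})
b = CachedFetcher(lambda uri: calls.append(uri) or {"uri": uri})
a.get("v1/file")
b.get("v1/file")  # served from the shared cache; the second fetch never runs
assert calls == ["v1/file"]
```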
| def test_resolver_with_list_schema(self): | |||
| """Test resolver with list as root schema""" | |||
| list_schema = [ | |||
| {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| {"type": "string"}, | |||
| {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"} | |||
| {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}, | |||
| ] | |||
| resolver = SchemaResolver() | |||
| resolved = resolver.resolve(list_schema) | |||
| assert isinstance(resolved, list) | |||
| assert len(resolved) == 3 | |||
| assert resolved[0]["type"] == "object" | |||
| @@ -534,20 +470,20 @@ class TestSchemaResolverClass: | |||
| assert resolved[1] == {"type": "string"} | |||
| assert resolved[2]["type"] == "object" | |||
| assert resolved[2]["title"] == "Q&A Structure Schema" | |||
| def test_cache_performance(self): | |||
| """Test that caching improves performance""" | |||
| SchemaResolver.clear_cache() | |||
| # Create a schema with many references to the same schema | |||
| schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| for i in range(50) # Reduced to avoid depth issues | |||
| } | |||
| }, | |||
| } | |||
| # First run (no cache) - run multiple times to warm up | |||
| results1 = [] | |||
| for _ in range(3): | |||
| @@ -556,9 +492,9 @@ class TestSchemaResolverClass: | |||
| result1 = resolve_dify_schema_refs(schema) | |||
| time_no_cache = time.perf_counter() - start | |||
| results1.append(time_no_cache) | |||
| avg_time_no_cache = sum(results1) / len(results1) | |||
| # Second run (with cache) - run multiple times | |||
| results2 = [] | |||
| for _ in range(3): | |||
| @@ -566,14 +502,14 @@ class TestSchemaResolverClass: | |||
| result2 = resolve_dify_schema_refs(schema) | |||
| time_with_cache = time.perf_counter() - start | |||
| results2.append(time_with_cache) | |||
| avg_time_with_cache = sum(results2) / len(results2) | |||
| # Cache should make it faster (more lenient check) | |||
| assert result1 == result2 | |||
| # Cache should provide some performance benefit | |||
| assert avg_time_with_cache <= avg_time_no_cache | |||
| def test_fast_path_performance_no_refs(self): | |||
| """Test that schemas without $refs use fast path and avoid deep copying""" | |||
| # Create a moderately complex schema without any $refs (typical plugin output_schema) | |||
| @@ -585,16 +521,13 @@ class TestSchemaResolverClass: | |||
| "properties": { | |||
| "name": {"type": "string"}, | |||
| "value": {"type": "number"}, | |||
| "items": { | |||
| "type": "array", | |||
| "items": {"type": "string"} | |||
| } | |||
| } | |||
| "items": {"type": "array", "items": {"type": "string"}}, | |||
| }, | |||
| } | |||
| for i in range(50) | |||
| } | |||
| }, | |||
| } | |||
| # Measure fast path (no refs) performance | |||
| fast_times = [] | |||
| for _ in range(10): | |||
| @@ -602,21 +535,21 @@ class TestSchemaResolverClass: | |||
| result_fast = resolve_dify_schema_refs(no_refs_schema) | |||
| elapsed = time.perf_counter() - start | |||
| fast_times.append(elapsed) | |||
| avg_fast_time = sum(fast_times) / len(fast_times) | |||
| # Most importantly: result should be identical to input (no copying) | |||
| assert result_fast is no_refs_schema | |||
| # Create schema with $refs for comparison (same structure size) | |||
| with_refs_schema = { | |||
| "type": "object", | |||
| "type": "object", | |||
| "properties": { | |||
| f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| for i in range(20) # Fewer to avoid depth issues but still comparable | |||
| } | |||
| }, | |||
| } | |||
| # Measure slow path (with refs) performance | |||
| SchemaResolver.clear_cache() | |||
| slow_times = [] | |||
| @@ -626,63 +559,54 @@ class TestSchemaResolverClass: | |||
| result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50) | |||
| elapsed = time.perf_counter() - start | |||
| slow_times.append(elapsed) | |||
| avg_slow_time = sum(slow_times) / len(slow_times) | |||
| # The key benefit: fast path should be reasonably fast (main goal is no deep copy) | |||
| # and definitely avoid the expensive BFS resolution | |||
| # Even if detection has some overhead, it should still be faster for typical cases | |||
| print(f"Fast path (no refs): {avg_fast_time:.6f}s") | |||
| print(f"Slow path (with refs): {avg_slow_time:.6f}s") | |||
| # More lenient check: fast path should be at least somewhat competitive | |||
| # The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster | |||
| assert avg_fast_time < avg_slow_time * 2 # Should not be more than 2x slower | |||
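Editor's note: the fast-path assertion (`result_fast is no_refs_schema`) hinges on the resolver returning the input object untouched when no Dify `$ref` is present, skipping both the deep copy and the BFS expansion. A hedged sketch of such an entry point, with the slow path stubbed out:

```python
import copy
import json
from typing import Any


def resolve_refs_sketch(schema: Any, resolver=None) -> Any:
    """Return the input object untouched when it contains no Dify $refs."""
    try:
        has_refs = '"$ref": "https://dify.ai/schemas/' in json.dumps(schema)
    except (TypeError, ValueError):
        has_refs = True  # be conservative if the schema is not JSON-serializable
    if not has_refs:
        return schema  # fast path: no copy at all
    # Slow-path placeholder: the real resolver walks the copy and expands the refs.
    resolved = copy.deepcopy(schema)
    return resolver(resolved) if resolver else resolved


plain = {"type": "object", "properties": {"name": {"type": "string"}}}
assert resolve_refs_sketch(plain) is plain  # identity preserved, as the test asserts
```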
| def test_batch_processing_performance(self): | |||
| """Test performance improvement for batch processing of schemas without refs""" | |||
| # Simulate the plugin tool scenario: many schemas, most without refs | |||
| schemas_without_refs = [ | |||
| { | |||
| "type": "object", | |||
| "properties": { | |||
| f"field_{j}": {"type": "string" if j % 2 else "number"} | |||
| for j in range(10) | |||
| } | |||
| "properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)}, | |||
| } | |||
| for i in range(100) | |||
| ] | |||
| # Test batch processing performance | |||
| start = time.perf_counter() | |||
| results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs] | |||
| batch_time = time.perf_counter() - start | |||
| # Verify all results are identical to inputs (fast path used) | |||
| for original, result in zip(schemas_without_refs, results): | |||
| assert result is original | |||
| # Should be very fast - each schema should take < 0.001 seconds on average | |||
| avg_time_per_schema = batch_time / len(schemas_without_refs) | |||
| assert avg_time_per_schema < 0.001 | |||
| def test_has_dify_refs_performance(self): | |||
| """Test that _has_dify_refs is fast for large schemas without refs""" | |||
| # Create a very large schema without refs | |||
| large_schema = { | |||
| "type": "object", | |||
| "properties": {} | |||
| } | |||
| large_schema = {"type": "object", "properties": {}} | |||
| # Add many nested properties | |||
| current = large_schema | |||
| for i in range(100): | |||
| current["properties"][f"level_{i}"] = { | |||
| "type": "object", | |||
| "properties": {} | |||
| } | |||
| current["properties"][f"level_{i}"] = {"type": "object", "properties": {}} | |||
| current = current["properties"][f"level_{i}"] | |||
| # _has_dify_refs should be fast even for large schemas | |||
| times = [] | |||
| for _ in range(50): | |||
| @@ -690,13 +614,13 @@ class TestSchemaResolverClass: | |||
| has_refs = _has_dify_refs(large_schema) | |||
| elapsed = time.perf_counter() - start | |||
| times.append(elapsed) | |||
| avg_time = sum(times) / len(times) | |||
| # Should be False and fast | |||
| assert not has_refs | |||
| assert avg_time < 0.01 # Should complete in less than 10ms | |||
| def test_hybrid_vs_recursive_performance(self): | |||
| """Test performance comparison between hybrid and recursive detection""" | |||
| # Create test schemas of different types and sizes | |||
| @@ -704,16 +628,9 @@ class TestSchemaResolverClass: | |||
| # Case 1: Small schema without refs (most common case) | |||
| { | |||
| "name": "small_no_refs", | |||
| "schema": { | |||
| "type": "object", | |||
| "properties": { | |||
| "name": {"type": "string"}, | |||
| "value": {"type": "number"} | |||
| } | |||
| }, | |||
| "expected": False | |||
| "schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}}, | |||
| "expected": False, | |||
| }, | |||
| # Case 2: Medium schema without refs | |||
| { | |||
| "name": "medium_no_refs", | |||
| @@ -725,28 +642,16 @@ class TestSchemaResolverClass: | |||
| "properties": { | |||
| "name": {"type": "string"}, | |||
| "value": {"type": "number"}, | |||
| "items": { | |||
| "type": "array", | |||
| "items": {"type": "string"} | |||
| } | |||
| } | |||
| "items": {"type": "array", "items": {"type": "string"}}, | |||
| }, | |||
| } | |||
| for i in range(20) | |||
| } | |||
| }, | |||
| }, | |||
| "expected": False | |||
| "expected": False, | |||
| }, | |||
| # Case 3: Large schema without refs | |||
| { | |||
| "name": "large_no_refs", | |||
| "schema": { | |||
| "type": "object", | |||
| "properties": {} | |||
| }, | |||
| "expected": False | |||
| }, | |||
| {"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False}, | |||
| # Case 4: Schema with Dify refs | |||
| { | |||
| "name": "with_dify_refs", | |||
| @@ -754,45 +659,38 @@ class TestSchemaResolverClass: | |||
| "type": "object", | |||
| "properties": { | |||
| "file": {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| "data": {"type": "string"} | |||
| } | |||
| "data": {"type": "string"}, | |||
| }, | |||
| }, | |||
| "expected": True | |||
| "expected": True, | |||
| }, | |||
| # Case 5: Schema with non-Dify refs | |||
| { | |||
| "name": "with_external_refs", | |||
| "schema": { | |||
| "type": "object", | |||
| "properties": { | |||
| "external": {"$ref": "https://example.com/schema.json"}, | |||
| "data": {"type": "string"} | |||
| } | |||
| "type": "object", | |||
| "properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}}, | |||
| }, | |||
| "expected": False | |||
| } | |||
| "expected": False, | |||
| }, | |||
| ] | |||
| # Add deep nesting to large schema | |||
| current = test_cases[2]["schema"] | |||
| for i in range(50): | |||
| current["properties"][f"level_{i}"] = { | |||
| "type": "object", | |||
| "properties": {} | |||
| } | |||
| current["properties"][f"level_{i}"] = {"type": "object", "properties": {}} | |||
| current = current["properties"][f"level_{i}"] | |||
| # Performance comparison | |||
| for test_case in test_cases: | |||
| schema = test_case["schema"] | |||
| expected = test_case["expected"] | |||
| name = test_case["name"] | |||
| # Test correctness first | |||
| assert _has_dify_refs_hybrid(schema) == expected | |||
| assert _has_dify_refs_recursive(schema) == expected | |||
| # Measure hybrid performance | |||
| hybrid_times = [] | |||
| for _ in range(10): | |||
| @@ -800,7 +698,7 @@ class TestSchemaResolverClass: | |||
| result_hybrid = _has_dify_refs_hybrid(schema) | |||
| elapsed = time.perf_counter() - start | |||
| hybrid_times.append(elapsed) | |||
| # Measure recursive performance | |||
| recursive_times = [] | |||
| for _ in range(10): | |||
| @@ -808,69 +706,62 @@ class TestSchemaResolverClass: | |||
| result_recursive = _has_dify_refs_recursive(schema) | |||
| elapsed = time.perf_counter() - start | |||
| recursive_times.append(elapsed) | |||
| avg_hybrid = sum(hybrid_times) / len(hybrid_times) | |||
| avg_recursive = sum(recursive_times) / len(recursive_times) | |||
| print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s") | |||
| # Results should be identical | |||
| assert result_hybrid == result_recursive == expected | |||
| # For schemas without refs, hybrid should be competitive or better | |||
| if not expected: # No refs case | |||
| # Hybrid might be slightly slower due to JSON serialization overhead, | |||
| # but should not be dramatically worse | |||
| assert avg_hybrid < avg_recursive * 5 # At most 5x slower | |||
| def test_string_matching_edge_cases(self): | |||
| """Test edge cases for string-based detection""" | |||
| # Case 1: False positive potential - $ref in description | |||
| schema_false_positive = { | |||
| "type": "object", | |||
| "properties": { | |||
| "description": { | |||
| "type": "string", | |||
| "description": "This field explains how $ref works in JSON Schema" | |||
| } | |||
| } | |||
| "description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"} | |||
| }, | |||
| } | |||
| # Both methods should return False | |||
| assert not _has_dify_refs_hybrid(schema_false_positive) | |||
| assert not _has_dify_refs_recursive(schema_false_positive) | |||
| # Case 2: Complex URL patterns | |||
| complex_schema = { | |||
| "type": "object", | |||
| "properties": { | |||
| "config": { | |||
| "type": "object", | |||
| "type": "object", | |||
| "properties": { | |||
| "dify_url": { | |||
| "type": "string", | |||
| "default": "https://dify.ai/schemas/info" | |||
| }, | |||
| "actual_ref": { | |||
| "$ref": "https://dify.ai/schemas/v1/file.json" | |||
| } | |||
| } | |||
| "dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"}, | |||
| "actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| }, | |||
| } | |||
| } | |||
| }, | |||
| } | |||
| # Both methods should return True (due to actual_ref) | |||
| assert _has_dify_refs_hybrid(complex_schema) | |||
| assert _has_dify_refs_recursive(complex_schema) | |||
| # Case 3: Non-JSON serializable objects (should fall back to recursive) | |||
| import datetime | |||
| non_serializable = { | |||
| "type": "object", | |||
| "timestamp": datetime.datetime.now(), | |||
| "data": {"$ref": "https://dify.ai/schemas/v1/file.json"} | |||
| "data": {"$ref": "https://dify.ai/schemas/v1/file.json"}, | |||
| } | |||
| # Hybrid should fall back to recursive and still work | |||
| assert _has_dify_refs_hybrid(non_serializable) | |||
| assert _has_dify_refs_recursive(non_serializable) | |||
| assert _has_dify_refs_recursive(non_serializable) | |||