
merge new graph engine

tags/2.0.0-beta.1
jyong 2 months ago
parent
commit
90d72f5ddf
35 changed files with 552 additions and 617 deletions
  1. +2 -2 api/commands.py
  2. +1 -1 api/controllers/console/datasets/rag_pipeline/datasource_auth.py
  3. +3 -2 api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
  4. +25 -39 api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
  5. +4 -0 api/controllers/console/datasets/wraps.py
  6. +1 -1 api/controllers/console/spec.py
  7. +3 -0 api/controllers/service_api/dataset/document.py
  8. +3 -1 api/core/agent/base_agent_runner.py
  9. +1 -1 api/core/app/apps/advanced_chat/app_generator.py
  10. +3 -1 api/core/app/apps/chat/app_runner.py
  11. +1 -1 api/core/app/apps/common/workflow_response_converter.py
  12. +45 -18 api/core/app/apps/pipeline/pipeline_runner.py
  13. +1 -1 api/core/plugin/impl/datasource.py
  14. +1 -1 api/core/rag/index_processor/index_processor_base.py
  15. +1 -1 api/core/schemas/__init__.py
  16. +21 -25 api/core/schemas/registry.py
  17. +90 -101 api/core/schemas/resolver.py
  18. +10 -13 api/core/schemas/schema_manager.py
  19. +4 -4 api/core/workflow/entities/variable_pool.py
  20. +4 -0 api/core/workflow/enums.py
  21. +1 -1 api/core/workflow/graph/graph.py
  22. +40 -29 api/core/workflow/nodes/datasource/datasource_node.py
  23. +5 -9 api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
  24. +3 -3 api/models/dataset.py
  25. +5 -0 api/models/provider_ids.py
  26. +15 -7 api/services/dataset_service.py
  27. +2 -4 api/services/datasource_provider_service.py
  28. +17 -3 api/services/entities/knowledge_entities/rag_pipeline_entities.py
  29. +23 -29 api/services/rag_pipeline/rag_pipeline.py
  30. +2 -2 api/services/rag_pipeline/rag_pipeline_dsl_service.py
  31. +3 -2 api/services/rag_pipeline/rag_pipeline_transform_service.py
  32. +4 -1 api/tasks/batch_clean_document_task.py
  33. +15 -12 api/tasks/rag_pipeline/rag_pipeline_run_task.py
  34. +1 -1 api/tests/unit_tests/core/schemas/__init__.py
  35. +192 -301 api/tests/unit_tests/core/schemas/test_resolver.py

+ 2
- 2
api/commands.py

@@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
@@ -35,7 +35,7 @@ from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import ToolProviderID
from models.provider_ids import DatasourceProviderID, ToolProviderID
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService

+ 1
- 1
api/controllers/console/datasets/rag_pipeline/datasource_auth.py

@@ -11,10 +11,10 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService


+ 3
- 2
api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py

@@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import db
from models.account import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -131,7 +132,7 @@ def _api_prerequisite(f):
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)


+ 25
- 39
api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py

@@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

# fetch draft workflow by app_model
@@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

content_type = request.headers.get("Content-Type", "")
@@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#


class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required
@login_required
@@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
@@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
@@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
if not pipeline.is_published:
return None
@@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

rag_pipeline_service = RagPipelineService()
@@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

# Get default block configs
@@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
@@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
@@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
@@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource):
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource):
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
@@ -926,8 +912,11 @@ class DatasourceListApi(Resource):
@account_initialization_required
def get(self):
user = current_user

if not isinstance(user, Account):
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()

return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))

@@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource):
"""
Set datasource variables
"""
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()

parser = reqparse.RequestParser()
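
One pattern repeats across the endpoints above: the previous split checks (`current_user.is_editor` in one place, `isinstance(current_user, Account)` in another, or only the former) are collapsed into the single guard `if not isinstance(current_user, Account) or not current_user.is_editor`. Below is a minimal sketch of that combined guard wrapped as a reusable decorator; `editor_account_required` and the stub `Account` class are hypothetical illustrations for this page, not helpers that exist in the codebase.

```python
from functools import wraps

from werkzeug.exceptions import Forbidden


class Account:
    """Stub standing in for models.account.Account."""

    is_editor = True


current_user = Account()  # stand-in for the flask-login proxy used by the controllers


def editor_account_required(f):
    """Reject the request unless current_user is an Account with editor rights."""

    @wraps(f)
    def wrapper(*args, **kwargs):
        # The combined guard the diff now applies inline in every endpoint.
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()
        return f(*args, **kwargs)

    return wrapper


@editor_account_required
def get_draft_rag_pipeline_workflow():
    return {"status": "ok"}


print(get_draft_rag_pipeline_workflow())  # {'status': 'ok'}
```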

+ 4
- 0
api/controllers/console/datasets/wraps.py

@@ -5,6 +5,7 @@ from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline


@@ -17,6 +18,9 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")

if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")

pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)


+ 1
- 1
api/controllers/console/spec.py

@@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource):
return [], 200


api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

+ 3
- 0
api/controllers/service_api/dataset/document.py

@@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource):
# validate args
DocumentService.document_create_args_validate(knowledge_config)

if not current_user:
raise ValueError("current_user is required")

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,

+ 3
- 1
api/core/agent/base_agent_runner.py

@@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
tenant_id=tenant_id,
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
return_resource=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,

+ 1
- 1
api/core/app/apps/advanced_chat/app_generator.py

@@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
app_config.additional_features.show_retrieve_source = True # type: ignore

workflow_run_id = str(uuid.uuid4())
# init application generate entity

+ 3
- 1
api/core/app/apps/chat/app_runner.py

@@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner):
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
hit_callback=hit_callback,
memory=memory,
message_id=message.id,

+ 1
- 1
api/core/app/apps/common/workflow_response_converter.py

@@ -36,8 +36,8 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.entities.tool_entities import ToolProviderType
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution

+ 45
- 18
api/core/app/apps/pipeline/pipeline_runner.py

@@ -1,8 +1,7 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast
import time
from typing import Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import (
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -22,7 +23,7 @@ from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType
from models.workflow import Workflow

logger = logging.getLogger(__name__)

@@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner):

db.session.close()

workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs
@@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)

rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
@@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

# init graph
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)

# RUN WORKFLOW
@@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
graph_runtime_state=graph_runtime_state,
)

generator = workflow_entry.run(callbacks=workflow_callbacks)
generator = workflow_entry.run()

for event in generator:
self._update_document_status(
@@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
# return workflow
return workflow

def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")

@@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)

node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)

if not graph:
raise ValueError("graph not found in workflow")
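
The changes above rewire PipelineRunner for the new graph engine: a `GraphRuntimeState` is now created up front (from a `VariablePool` and a `time.perf_counter()` start time) and shared by the `DifyNodeFactory` and `Graph.init`, which also takes a `root_node_id`. The following is a self-contained sketch of that wiring order using simplified stand-in classes; the names mirror the diff, but none of the implementations below are the real `core.workflow` classes.

```python
import time
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class VariablePool:
    variables: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def empty(cls) -> "VariablePool":
        return cls()


@dataclass
class GraphRuntimeState:
    variable_pool: VariablePool
    start_at: float


@dataclass
class NodeFactory:
    graph_runtime_state: GraphRuntimeState

    def create(self, node_config: dict) -> dict:
        # The real factory builds typed Node objects; a dict is enough for the sketch.
        return {"id": node_config["id"], "state": self.graph_runtime_state}


def init_graph(graph_config: dict, node_factory: NodeFactory, root_node_id: Optional[str] = None) -> dict:
    nodes = [node_factory.create(n) for n in graph_config["nodes"]]
    return {"root": root_node_id or nodes[0]["id"], "nodes": nodes}


# Runtime state is created first and shared by the factory and the graph,
# mirroring the ordering the diff introduces in PipelineRunner.
runtime_state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=time.perf_counter())
graph = init_graph(
    {"nodes": [{"id": "datasource_1"}, {"id": "knowledge_index_1"}], "edges": []},
    node_factory=NodeFactory(graph_runtime_state=runtime_state),
    root_node_id="datasource_1",
)
print(graph["root"])  # datasource_1
```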

+ 1
- 1
api/core/plugin/impl/datasource.py

@@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import (
OnlineDriveDownloadFileRequest,
WebsiteCrawlMessage,
)
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.schemas.resolver import resolve_dify_schema_refs
from models.provider_ids import DatasourceProviderID, GenericProviderID
from services.tools.tools_transform_service import ToolTransformService



+ 1
- 1
api/core/rag/index_processor/index_processor_base.py

@@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from configs import dify_config
from core.rag.extractor.entity.extract_setting import ExtractSetting

+ 1
- 1
api/core/schemas/__init__.py

@@ -2,4 +2,4 @@

from .resolver import resolve_dify_schema_refs

__all__ = ["resolve_dify_schema_refs"]
__all__ = ["resolve_dify_schema_refs"]

+ 21
- 25
api/core/schemas/registry.py

@@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional

class SchemaRegistry:
"""Schema registry manages JSON schemas with version support"""
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_lock: ClassVar[threading.Lock] = threading.Lock()

@@ -25,41 +25,41 @@ class SchemaRegistry:
if cls._default_instance is None:
current_dir = Path(__file__).parent
schema_dir = current_dir / "builtin" / "schemas"
registry = cls(str(schema_dir))
registry.load_all_versions()
cls._default_instance = registry
return cls._default_instance

def load_all_versions(self) -> None:
"""Scans the schema directory and loads all versions"""
if not self.base_dir.exists():
return
for entry in self.base_dir.iterdir():
if not entry.is_dir():
continue
version = entry.name
if not version.startswith("v"):
continue
self._load_version_dir(version, entry)

def _load_version_dir(self, version: str, version_dir: Path) -> None:
"""Loads all schemas in a version directory"""
if not version_dir.exists():
return
if version not in self.versions:
self.versions[version] = {}
for entry in version_dir.iterdir():
if entry.suffix != ".json":
continue
schema_name = entry.stem
self._load_schema(version, schema_name, entry)

@@ -68,10 +68,10 @@ class SchemaRegistry:
try:
with open(schema_path, encoding="utf-8") as f:
schema = json.load(f)
# Store the schema
self.versions[version][schema_name] = schema
# Extract and store metadata
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
metadata = {
@@ -81,26 +81,26 @@ class SchemaRegistry:
"deprecated": schema.get("deprecated", False),
}
self.metadata[uri] = metadata
except (OSError, json.JSONDecodeError) as e:
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")


def get_schema(self, uri: str) -> Optional[Any]:
"""Retrieves a schema by URI with version support"""
version, schema_name = self._parse_uri(uri)
if not version or not schema_name:
return None
version_schemas = self.versions.get(version)
if not version_schemas:
return None
return version_schemas.get(schema_name)

def _parse_uri(self, uri: str) -> tuple[str, str]:
"""Parses a schema URI to extract version and schema name"""
from core.schemas.resolver import parse_dify_schema_uri

return parse_dify_schema_uri(uri)

def list_versions(self) -> list[str]:
@@ -112,19 +112,15 @@ class SchemaRegistry:
version_schemas = self.versions.get(version)
if not version_schemas:
return []
return sorted(version_schemas.keys())

def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
"""Returns all schemas for a version in the API format"""
version_schemas = self.versions.get(version, {})
result = []
for schema_name, schema in version_schemas.items():
result.append({
"name": schema_name,
"label": schema.get("title", schema_name),
"schema": schema
})
return result
result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema})

return result
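
The registry above scans a `builtin/schemas` directory for version folders (`v1`, `v2`, …) and loads every `*.json` file inside them, keyed by version and file stem. A stripped-down, self-contained sketch of that layout and loading logic, assuming the same directory convention (the class below is an illustration, not the project's `SchemaRegistry`):

```python
import json
from pathlib import Path


class MiniSchemaRegistry:
    """Simplified stand-in for the versioned registry shown above."""

    def __init__(self, base_dir: str):
        self.base_dir = Path(base_dir)
        self.versions: dict[str, dict[str, dict]] = {}

    def load_all_versions(self) -> None:
        if not self.base_dir.exists():
            return
        for entry in self.base_dir.iterdir():
            # Only version directories such as "v1" or "v2" are scanned.
            if entry.is_dir() and entry.name.startswith("v"):
                self._load_version_dir(entry.name, entry)

    def _load_version_dir(self, version: str, version_dir: Path) -> None:
        schemas = self.versions.setdefault(version, {})
        for entry in version_dir.iterdir():
            if entry.suffix == ".json":
                with open(entry, encoding="utf-8") as f:
                    schemas[entry.stem] = json.load(f)

    def get_schema(self, version: str, name: str):
        return self.versions.get(version, {}).get(name)


registry = MiniSchemaRegistry("builtin/schemas")
registry.load_all_versions()
print(registry.get_schema("v1", "document"))  # None until builtin/schemas/v1/document.json exists
```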

+ 90
- 101
api/core/schemas/resolver.py

@@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$

class SchemaResolutionError(Exception):
"""Base exception for schema resolution errors"""

pass


class CircularReferenceError(SchemaResolutionError):
"""Raised when a circular reference is detected"""

def __init__(self, ref_uri: str, ref_path: list[str]):
self.ref_uri = ref_uri
self.ref_path = ref_path
@@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError):

class MaxDepthExceededError(SchemaResolutionError):
"""Raised when maximum resolution depth is exceeded"""

def __init__(self, max_depth: int):
self.max_depth = max_depth
super().__init__(f"Maximum resolution depth ({max_depth}) exceeded")
@@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError):

class SchemaNotFoundError(SchemaResolutionError):
"""Raised when a referenced schema cannot be found"""

def __init__(self, ref_uri: str):
self.ref_uri = ref_uri
super().__init__(f"Schema not found: {ref_uri}")
@@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError):
@dataclass
class QueueItem:
"""Represents an item in the BFS queue"""

current: Any
parent: Optional[Any]
key: Optional[Union[str, int]]
@@ -56,39 +61,39 @@ class QueueItem:

class SchemaResolver:
"""Resolver for Dify schema references with caching and optimizations"""
_cache: dict[str, SchemaDict] = {}
_cache_lock = threading.Lock()
def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
"""
Initialize the schema resolver
Args:
registry: Schema registry to use (defaults to default registry)
max_depth: Maximum depth for reference resolution
"""
self.registry = registry or SchemaRegistry.default_registry()
self.max_depth = max_depth
@classmethod
def clear_cache(cls) -> None:
"""Clear the global schema cache"""
with cls._cache_lock:
cls._cache.clear()
def resolve(self, schema: SchemaType) -> SchemaType:
"""
Resolve all $ref references in the schema
Performance optimization: quickly checks for $ref presence before processing.
Args:
schema: Schema to resolve
Returns:
Resolved schema with all references expanded
Raises:
CircularReferenceError: If circular reference detected
MaxDepthExceededError: If max depth exceeded
@@ -96,44 +101,39 @@ class SchemaResolver:
"""
if not isinstance(schema, (dict, list)):
return schema
# Fast path: if no Dify refs found, return original schema unchanged
# This avoids expensive deepcopy and BFS traversal for schemas without refs
if not _has_dify_refs(schema):
return schema
# Slow path: schema contains refs, perform full resolution
import copy

result = copy.deepcopy(schema)
# Initialize BFS queue
queue = deque([QueueItem(
current=result,
parent=None,
key=None,
depth=0,
ref_path=set()
)])
queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())])

while queue:
item = queue.popleft()

# Process the current item
self._process_queue_item(queue, item)

return result
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
"""Process a single queue item"""
if isinstance(item.current, dict):
self._process_dict(queue, item)
elif isinstance(item.current, list):
self._process_list(queue, item)
def _process_dict(self, queue: deque, item: QueueItem) -> None:
"""Process a dictionary item"""
ref_uri = item.current.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
# Handle $ref resolution
self._resolve_ref(queue, item, ref_uri)
@@ -144,14 +144,10 @@ class SchemaResolver:
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
queue.append(QueueItem(
current=value,
parent=item.current,
key=key,
depth=next_depth,
ref_path=item.ref_path
))
queue.append(
QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path)
)

def _process_list(self, queue: deque, item: QueueItem) -> None:
"""Process a list item"""
for idx, value in enumerate(item.current):
@@ -159,14 +155,10 @@ class SchemaResolver:
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
queue.append(QueueItem(
current=value,
parent=item.current,
key=idx,
depth=next_depth,
ref_path=item.ref_path
))
queue.append(
QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path)
)

def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
"""Resolve a $ref reference"""
# Check for circular reference
@@ -175,82 +167,78 @@ class SchemaResolver:
item.current["$circular_ref"] = True
logger.warning("Circular reference detected: %s", ref_uri)
return
# Get resolved schema (from cache or registry)
resolved_schema = self._get_resolved_schema(ref_uri)
if not resolved_schema:
logger.warning("Schema not found: %s", ref_uri)
return
# Update ref path
new_ref_path = item.ref_path | {ref_uri}
# Replace the reference with resolved schema
next_depth = item.depth + 1
if next_depth >= self.max_depth:
raise MaxDepthExceededError(self.max_depth)
if item.parent is None:
# Root level replacement
item.current.clear()
item.current.update(resolved_schema)
queue.append(QueueItem(
current=item.current,
parent=None,
key=None,
depth=next_depth,
ref_path=new_ref_path
))
queue.append(
QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path)
)
else:
# Update parent container
item.parent[item.key] = resolved_schema.copy()
queue.append(QueueItem(
current=item.parent[item.key],
parent=item.parent,
key=item.key,
depth=next_depth,
ref_path=new_ref_path
))
queue.append(
QueueItem(
current=item.parent[item.key],
parent=item.parent,
key=item.key,
depth=next_depth,
ref_path=new_ref_path,
)
)

def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
"""Get resolved schema from cache or registry"""
# Check cache first
with self._cache_lock:
if ref_uri in self._cache:
return self._cache[ref_uri].copy()
# Fetch from registry
schema = self.registry.get_schema(ref_uri)
if not schema:
return None
# Clean and cache
cleaned = _remove_metadata_fields(schema)
with self._cache_lock:
self._cache[ref_uri] = cleaned
return cleaned.copy()


def resolve_dify_schema_refs(
schema: SchemaType,
registry: Optional[SchemaRegistry] = None,
max_depth: int = 30
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
) -> SchemaType:
"""
Resolve $ref references in Dify schema to actual schema content
This is a convenience function that creates a resolver and resolves the schema.
Performance optimization: quickly checks for $ref presence before processing.
Args:
schema: Schema object that may contain $ref references
registry: Optional schema registry, defaults to default registry
max_depth: Maximum depth to prevent infinite loops (default: 30)
Returns:
Schema with all $ref references resolved to actual content
Raises:
CircularReferenceError: If circular reference detected
MaxDepthExceededError: If maximum depth exceeded
@@ -260,7 +248,7 @@ def resolve_dify_schema_refs(
# This avoids expensive deepcopy and BFS traversal for schemas without refs
if not _has_dify_refs(schema):
return schema
# Slow path: schema contains refs, perform full resolution
resolver = SchemaResolver(registry, max_depth)
return resolver.resolve(schema)
@@ -269,36 +257,36 @@ def resolve_dify_schema_refs(
def _remove_metadata_fields(schema: dict) -> dict:
"""
Remove metadata fields from schema that shouldn't be included in resolved output
Args:
schema: Schema dictionary
Returns:
Cleaned schema without metadata fields
"""
# Create a copy and remove metadata fields
cleaned = schema.copy()
metadata_fields = ["$id", "$schema", "version"]
for field in metadata_fields:
cleaned.pop(field, None)
return cleaned


def _is_dify_schema_ref(ref_uri: Any) -> bool:
"""
Check if the reference URI is a Dify schema reference
Args:
ref_uri: URI to check
Returns:
True if it's a Dify schema reference
"""
if not isinstance(ref_uri, str):
return False
# Use pre-compiled pattern for better performance
return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri))

@@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool:
def _has_dify_refs_recursive(schema: SchemaType) -> bool:
"""
Recursively check if a schema contains any Dify $ref references
This is the fallback method when string-based detection is not possible.
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""
@@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
ref_uri = schema.get("$ref")
if ref_uri and _is_dify_schema_ref(ref_uri):
return True
# Check nested values
for value in schema.values():
if _has_dify_refs_recursive(value):
return True
elif isinstance(schema, list):
# Check each item in the list
for item in schema:
if _has_dify_refs_recursive(item):
return True
# Primitive types don't contain refs
return False

@@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
"""
Hybrid detection: fast string scan followed by precise recursive check
Performance optimization using two-phase detection:
1. Fast string scan to quickly eliminate schemas without $ref
2. Precise recursive validation only for potential candidates
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""
# Phase 1: Fast string-based pre-filtering
try:
import json
schema_str = json.dumps(schema, separators=(',', ':'))

schema_str = json.dumps(schema, separators=(",", ":"))

# Quick elimination: no $ref at all
if '"$ref"' not in schema_str:
return False
# Quick elimination: no Dify schema URLs
if 'https://dify.ai/schemas/' not in schema_str:
if "https://dify.ai/schemas/" not in schema_str:
return False
except (TypeError, ValueError, OverflowError):
# JSON serialization failed (e.g., circular references, non-serializable objects)
# Fall back to recursive detection
logger.debug("JSON serialization failed for schema, using recursive detection")
return _has_dify_refs_recursive(schema)
# Phase 2: Precise recursive validation
# Only executed for schemas that passed string pre-filtering
return _has_dify_refs_recursive(schema)
@@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
def _has_dify_refs(schema: SchemaType) -> bool:
"""
Check if a schema contains any Dify $ref references
Uses hybrid detection for optimal performance:
- Fast string scan for quick elimination
- Fast string scan for quick elimination
- Precise recursive check for validation
Args:
schema: Schema to check for references
Returns:
True if any Dify $ref is found, False otherwise
"""
@@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool:
def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
"""
Parse a Dify schema URI to extract version and schema name
Args:
uri: Schema URI to parse
Returns:
Tuple of (version, schema_name) or ("", "") if invalid
"""
match = _DIFY_SCHEMA_PATTERN.match(uri)
if not match:
return "", ""
return match.group(1), match.group(2)
return match.group(1), match.group(2)
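
The resolver keeps its two-phase `$ref` detection: a cheap string scan (`json.dumps` plus substring checks) rules out most schemas before the precise recursive walk runs. A standalone sketch of that hybrid check, reusing the URI pattern visible in the diff (function names here are illustrative):

```python
import json
import re

# Mirrors the _DIFY_SCHEMA_PATTERN visible in the diff.
DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")


def _has_refs_recursive(schema) -> bool:
    if isinstance(schema, dict):
        ref = schema.get("$ref")
        if isinstance(ref, str) and DIFY_SCHEMA_PATTERN.match(ref):
            return True
        return any(_has_refs_recursive(v) for v in schema.values())
    if isinstance(schema, list):
        return any(_has_refs_recursive(item) for item in schema)
    return False  # primitives cannot contain refs


def has_dify_refs(schema) -> bool:
    """Phase 1: cheap string scan; phase 2: precise recursive walk only if needed."""
    try:
        schema_str = json.dumps(schema, separators=(",", ":"))
        if '"$ref"' not in schema_str or "https://dify.ai/schemas/" not in schema_str:
            return False
    except (TypeError, ValueError, OverflowError):
        # Non-serializable input: skip the fast path and walk the structure directly.
        pass
    return _has_refs_recursive(schema)


print(has_dify_refs({"type": "object"}))  # False
print(has_dify_refs({"items": {"$ref": "https://dify.ai/schemas/v1/document.json"}}))  # True
```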

+ 10
- 13
api/core/schemas/schema_manager.py

@@ -13,10 +13,10 @@ class SchemaManager:
def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
"""
Get all JSON Schema definitions for a specific version
Args:
version: Schema version, defaults to v1
Returns:
Array containing schema definitions, each element contains name and schema fields
"""
@@ -25,31 +25,28 @@ class SchemaManager:
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
"""
Get a specific schema by name
Args:
schema_name: Schema name
version: Schema version, defaults to v1
Returns:
Dictionary containing name and schema, returns None if not found
"""
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
schema = self.registry.get_schema(uri)
if schema:
return {
"name": schema_name,
"schema": schema
}
return {"name": schema_name, "schema": schema}
return None

def list_available_schemas(self, version: str = "v1") -> list[str]:
"""
List all available schema names for a specific version
Args:
version: Schema version, defaults to v1
Returns:
List of schema names
"""
@@ -58,8 +55,8 @@ class SchemaManager:
def list_available_versions(self) -> list[str]:
"""
List all available schema versions
Returns:
List of versions
"""
return self.registry.list_versions()
return self.registry.list_versions()

+ 4
- 4
api/core/workflow/entities/variable_pool.py

@@ -68,10 +68,10 @@ class VariablePool(BaseModel):
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map = defaultdict(dict)
for var in self.rag_pipeline_variables:
node_id = var.variable.belong_to_node_id
key = var.variable.variable
value = var.value
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable
value = rag_var.value
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)

+ 4
- 0
api/core/workflow/enums.py

@@ -37,12 +37,14 @@ class NodeType(StrEnum):
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
@@ -83,6 +85,7 @@ class WorkflowType(StrEnum):

WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"


class WorkflowExecutionStatus(StrEnum):
@@ -116,6 +119,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"


class WorkflowNodeExecutionStatus(StrEnum):

+ 1
- 1
api/core/workflow/graph/graph.py

@@ -109,7 +109,7 @@ class Graph:
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data", {})
if node_data.get("type") == NodeType.START.value:
if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]:
start_node_id = nid
break
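
The one-line change above lets a graph start from either a `start` node or a `datasource` node when picking the root among candidates. A tiny standalone illustration of that selection rule, with plain strings standing in for the `NodeType` enum:

```python
from typing import Optional

# A graph root may now be either a "start" node or a "datasource" node.
ROOT_NODE_TYPES = {"start", "datasource"}


def pick_root_node_id(node_configs: list[dict]) -> Optional[str]:
    for config in node_configs:
        if config.get("data", {}).get("type") in ROOT_NODE_TYPES:
            return config.get("id")
    return None


nodes = [
    {"id": "datasource_1", "data": {"type": "datasource"}},
    {"id": "llm_1", "data": {"type": "llm"}},
]
print(pick_root_node_id(nodes))  # datasource_1
```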


+ 40
- 29
api/core/workflow/nodes/datasource/datasource_node.py

@@ -19,16 +19,14 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
@@ -39,7 +37,7 @@ from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError


class DatasourceNode(BaseNode):
class DatasourceNode(Node):
"""
Datasource Node
"""
@@ -97,8 +95,8 @@ class DatasourceNode(BaseNode):
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -172,8 +170,8 @@ class DatasourceNode(BaseNode):
datasource_type=datasource_type,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -204,10 +202,10 @@ class DatasourceNode(BaseNode):
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], file_info)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -220,8 +218,8 @@ class DatasourceNode(BaseNode):
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -230,8 +228,8 @@ class DatasourceNode(BaseNode):
)
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -425,8 +423,10 @@ class DatasourceNode(BaseNode):
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
@@ -442,7 +442,11 @@ class DatasourceNode(BaseNode):
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
@@ -454,17 +458,24 @@ class DatasourceNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value

yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])

yield RunCompletedEvent(
run_result=NodeRunResult(
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"json": json, "files": files, **variables, "text": text},
metadata={
@@ -526,9 +537,9 @@ class DatasourceNode(BaseNode):
tenant_id=self.tenant_id,
)
if file:
variable_pool.add([self.node_id, "file"], file)
yield RunCompletedEvent(
run_result=NodeRunResult(
variable_pool.add([self._node_id, "file"], file)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
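
The datasource node migrates from `RunStreamChunkEvent`/`RunCompletedEvent` to `StreamChunkEvent`/`StreamCompletedEvent`, keys each chunk by a `[node_id, variable]` selector, and now closes the text stream with an empty `is_final=True` chunk before reporting the node result. A hedged sketch of that event ordering with hypothetical dataclasses (field names copied from the diff, everything else simplified):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class StreamChunkEvent:
    selector: list[str]
    chunk: str
    is_final: bool = False


@dataclass
class StreamCompletedEvent:
    node_run_result: dict[str, Any]


def run_datasource_node(node_id: str, pieces: list[str]):
    """Yield text chunks keyed by [node_id, "text"], a final marker, then the node result."""
    text = ""
    for piece in pieces:
        text += piece
        yield StreamChunkEvent(selector=[node_id, "text"], chunk=piece)
    # Mark the end of the text stream before reporting the result, as the diff does.
    yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
    yield StreamCompletedEvent(node_run_result={"status": "succeeded", "outputs": {"text": text}})


for event in run_datasource_node("datasource_1", ["Link: https://example.com\n"]):
    print(event)
```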

+ 5
- 9
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py

@@ -9,16 +9,15 @@ from sqlalchemy import func
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

from ..base import BaseNode
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
@@ -35,7 +34,7 @@ default_retrieval_model = {
}


class KnowledgeIndexNode(BaseNode):
class KnowledgeIndexNode(Node):
_node_data: KnowledgeIndexNodeData
_node_type = NodeType.KNOWLEDGE_INDEX

@@ -93,15 +92,12 @@ class KnowledgeIndexNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)

except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")

+ 3
- 3
api/models/dataset.py

@@ -172,7 +172,7 @@ class Dataset(Base):
)

@property
def doc_form(self):
def doc_form(self) -> Optional[str]:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
@@ -424,7 +424,7 @@ class Document(Base):
return status

@property
def data_source_info_dict(self):
def data_source_info_dict(self) -> dict[str, Any]:
if self.data_source_info:
try:
data_source_info_dict = json.loads(self.data_source_info)
@@ -432,7 +432,7 @@ class Document(Base):
data_source_info_dict = {}

return data_source_info_dict
return None
return {}

@property
def data_source_detail_dict(self):

+ 5
- 0
api/models/provider_ids.py

@@ -52,3 +52,8 @@ class ToolProviderID(GenericProviderID):
if self.organization == "langgenius":
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
self.plugin_name = f"{self.provider_name}_tool"


class DatasourceProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)

+ 15
- 7
api/services/dataset_service.py

@@ -718,9 +718,9 @@ class DatasetService:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_configuration.embedding_model_provider,
provider=knowledge_configuration.embedding_model_provider or "",
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model,
model=knowledge_configuration.embedding_model or "",
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
@@ -1159,7 +1159,7 @@ class DocumentService:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file"
]
@@ -1281,7 +1281,7 @@ class DocumentService:
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
) -> tuple[list[Document], str]:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
@@ -1386,7 +1386,7 @@ class DocumentService:
"Invalid process rule mode: %s, can not find dataset process rule",
process_rule.mode,
)
return
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
@@ -2595,7 +2595,9 @@ class SegmentService:
return segment_data_list

@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
def update_segment(
cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset
) -> DocumentSegment:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@@ -2764,6 +2766,8 @@ class SegmentService:
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
if not new_segment:
raise ValueError("new_segment is not found")
return new_segment

@classmethod
@@ -2804,7 +2808,11 @@ class SegmentService:
index_node_ids = [seg.index_node_id for seg in segments]
total_words = sum(seg.word_count for seg in segments)

document.word_count -= total_words
if document.word_count is None:
document.word_count = 0
else:
document.word_count = max(0, document.word_count - total_words)

db.session.add(document)

delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)

+ 2
- 4
api/services/datasource_provider_service.py

@@ -11,7 +11,6 @@ from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
@@ -19,6 +18,7 @@ from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService

logger = logging.getLogger(__name__)
@@ -809,9 +809,7 @@ class DatasourceProviderService:
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
datasource_credentials.append(
{
"provider": datasource.provider,

+ 17
- 3
api/services/entities/knowledge_entities/rag_pipeline_entities.py

@@ -1,6 +1,6 @@
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic import BaseModel, field_validator


class IconInfo(BaseModel):
@@ -110,7 +110,21 @@ class KnowledgeConfiguration(BaseModel):

chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""
embedding_model: Optional[str] = ""
embedding_model_provider: str = ""
embedding_model: str = ""
keyword_number: Optional[int] = 10
retrieval_model: RetrievalSetting

@field_validator("embedding_model_provider", mode="before")
@classmethod
def validate_embedding_model_provider(cls, v):
if v is None:
return ""
return v

@field_validator("embedding_model", mode="before")
@classmethod
def validate_embedding_model(cls, v):
if v is None:
return ""
return v
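
The entity change above makes `embedding_model_provider` and `embedding_model` plain `str` fields and adds `mode="before"` validators that coerce `None` to `""`. A small runnable pydantic v2 sketch of the same idea, merging the two validators into one for brevity (the model name below is illustrative):

```python
from typing import Literal

from pydantic import BaseModel, field_validator


class KnowledgeConfigurationSketch(BaseModel):
    """Trimmed-down model showing the None -> "" coercion added in the diff."""

    indexing_technique: Literal["high_quality", "economy"]
    embedding_model_provider: str = ""
    embedding_model: str = ""

    @field_validator("embedding_model_provider", "embedding_model", mode="before")
    @classmethod
    def empty_string_if_none(cls, v):
        # Older payloads may still send null; normalize it so the fields stay plain str.
        return "" if v is None else v


print(KnowledgeConfigurationSketch(indexing_technique="economy", embedding_model=None))
```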

+ 23
- 29
api/services/rag_pipeline/rag_pipeline.py

@@ -28,26 +28,23 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import (
BaseDatasourceEvent,
DatasourceCompletedEvent,
DatasourceErrorEvent,
DatasourceProcessingEvent,
)
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.node_events.base import NodeRunResult
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.system_variable import SystemVariable
@@ -105,12 +102,13 @@ class RagPipelineService:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
return built_in_result
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
return result
customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
return customized_result

@classmethod
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
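
Splitting the shared `result` variable into `built_in_result` and `customized_result` avoids annotating the same name in both branches, which mypy flags as a redefinition. A minimal illustration of the alternative single-annotation style; names and payloads here are hypothetical:

from typing import Optional


def get_template_detail_sketch(built_in: bool) -> Optional[dict]:
    # Annotate once up front; plain assignments in each branch then need no
    # repeated annotation and draw no "already defined" complaint from mypy.
    detail: Optional[dict] = None
    if built_in:
        detail = {"source": "built-in"}  # placeholder payload
    else:
        detail = {"source": "customized"}  # placeholder payload
    return detail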
@@ -471,7 +469,7 @@ class RagPipelineService:
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
) -> Generator[BaseDatasourceEvent, None, None]:
) -> Generator[Mapping[str, Any], None, None]:
"""
Run published workflow datasource
"""
@@ -563,9 +561,9 @@ class RagPipelineService:
user_id=account.id,
request=OnlineDriveBrowseFilesRequest(
bucket=user_inputs.get("bucket"),
prefix=user_inputs.get("prefix"),
prefix=user_inputs.get("prefix", ""),
max_keys=user_inputs.get("max_keys", 20),
start_after=user_inputs.get("start_after"),
next_page_parameters=user_inputs.get("next_page_parameters"),
),
provider_type=datasource_runtime.datasource_provider_type(),
)
@@ -600,7 +598,7 @@ class RagPipelineService:
end_time = time.time()
if message.result.status == "completed":
crawl_event = DatasourceCompletedEvent(
data=message.result.web_info_list,
data=message.result.web_info_list or [],
total=message.result.total,
completed=message.result.completed,
time_consuming=round(end_time - start_time, 2),
@@ -681,9 +679,9 @@ class RagPipelineService:
datasource_runtime.get_online_document_page_content(
user_id=account.id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=user_inputs.get("workspace_id"),
page_id=user_inputs.get("page_id"),
type=user_inputs.get("type"),
workspace_id=user_inputs.get("workspace_id", ""),
page_id=user_inputs.get("page_id", ""),
type=user_inputs.get("type", ""),
),
provider_type=datasource_type,
)
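
Both request constructors above (OnlineDriveBrowseFilesRequest and GetOnlineDocumentPageContentRequest) now read the loosely typed `user_inputs` mapping with explicit fallbacks so their non-Optional string fields always receive a `str`. A minimal sketch of the pattern; the request class below is a simplified stand-in, not the real entity:

from pydantic import BaseModel


class PageContentRequestSketch(BaseModel):  # simplified stand-in for GetOnlineDocumentPageContentRequest
    workspace_id: str
    page_id: str
    type: str


def build_request(user_inputs: dict) -> PageContentRequestSketch:
    # user_inputs comes from the client and may omit keys; defaulting to ""
    # covers the missing-key case (an explicit None would still fail validation).
    return PageContentRequestSketch(
        workspace_id=user_inputs.get("workspace_id", ""),
        page_id=user_inputs.get("page_id", ""),
        type=user_inputs.get("type", ""),
    )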
@@ -740,7 +738,7 @@ class RagPipelineService:

def _handle_node_run_result(
self,
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
getter: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
start_at: float,
tenant_id: str,
node_id: str,
@@ -758,17 +756,16 @@ class RagPipelineService:

node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result

if isinstance(event, StreamCompletedEvent):
node_run_result = event.node_run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
break

if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error:
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.error_strategy:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
@@ -808,7 +805,7 @@ class RagPipelineService:
workflow_id=node_instance.workflow_id,
index=1,
node_id=node_id,
node_type=node_instance.type_,
node_type=node_instance.node_type,
title=node_instance.title,
elapsed_time=time.perf_counter() - start_at,
finished_at=datetime.now(UTC).replace(tzinfo=None),
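
With the new graph engine, single-step debugging consumes graph node events and takes its result from `StreamCompletedEvent.node_run_result` instead of the old `RunCompletedEvent.run_result`; a FAILED result is surfaced as EXCEPTION only when the node declares an `error_strategy`. A condensed sketch of that loop, using the imports introduced above; signatures are simplified and the result rebuild is elided, so this is not the full method:

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events.base import NodeRunResult
from core.workflow.node_events.node import StreamCompletedEvent


def collect_node_result(node, events) -> NodeRunResult:
    # `events` is the generator returned alongside the Node instance; the loop
    # stops at the first StreamCompletedEvent, mirroring the handler above.
    node_run_result = None
    for event in events:
        if isinstance(event, StreamCompletedEvent):
            node_run_result = event.node_run_result
            break
    if node_run_result is None:
        raise ValueError("Node run failed with no run result")
    if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy:
        # The service rebuilds the result with status=EXCEPTION at this point
        # (see the node_error_args block above); the rebuild is omitted here.
        ...
    return node_run_result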
@@ -1148,7 +1145,7 @@ class RagPipelineService:
.first()
)
return node_exec
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser):
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
@@ -1208,6 +1205,3 @@ class RagPipelineService:
)
session.commit()
return workflow_node_execution_db_model


+ 2  - 2    api/services/rag_pipeline/rag_pipeline_dsl_service.py

@@ -23,8 +23,8 @@ from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.enums import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
@@ -281,7 +281,7 @@ class RagPipelineDslService:
icon = icon_info.icon
icon_background = icon_info.icon_background
icon_url = icon_info.icon_url
else:
else:
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")

+ 3  - 2    api/services/rag_pipeline/rag_pipeline_transform_service.py

@@ -1,6 +1,7 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Optional
from uuid import uuid4

import yaml
@@ -87,7 +88,7 @@ class RagPipelineTransformService:
"status": "success",
}

def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str):
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
if doc_form == "text_model":
match datasource_type:
case "upload_file":
@@ -148,7 +149,7 @@ class RagPipelineTransformService:
return node

def _deal_knowledge_index(
self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict
self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)

+ 4  - 1    api/tasks/batch_clean_document_task.py

@@ -1,5 +1,6 @@
import logging
import time
from typing import Optional

import click
from celery import shared_task
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids
@@ -29,6 +30,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
start_at = time.perf_counter()

try:
if not doc_form:
raise ValueError("doc_form is required")
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()

if not dataset:

+ 15  - 12    api/tasks/rag_pipeline/rag_pipeline_run_task.py

@@ -21,14 +21,16 @@ from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom


@shared_task(queue="dataset")
def rag_pipeline_run_task(pipeline_id: str,
application_generate_entity: dict,
user_id: str,
tenant_id: str,
workflow_id: str,
streaming: bool,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None):
def rag_pipeline_run_task(
pipeline_id: str,
application_generate_entity: dict,
user_id: str,
tenant_id: str,
workflow_id: str,
streaming: bool,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
):
"""
Async Run rag pipeline
:param pipeline_id: Pipeline ID
@@ -94,18 +96,19 @@ def rag_pipeline_run_task(pipeline_id: str,
with current_app.app_context():
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for thread (after setting user)
context = contextvars.copy_context()
# Get Flask app object in the main thread where app context exists
flask_app = current_app._get_current_object() # type: ignore
# Create a wrapper function that passes user context
def _run_with_user_context():
# Don't create a new app context here - let _generate handle it
# Just ensure the user is available in contextvars
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
flask_app=flask_app,
@@ -120,7 +123,7 @@ def rag_pipeline_run_task(pipeline_id: str,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
# Create and start worker thread
worker_thread = threading.Thread(target=_run_with_user_context)
worker_thread.start()
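
The reformatted task body keeps the established hand-off pattern: capture the concrete Flask app object and the logged-in account in the main thread, copy the context after `g._login_user` is set, then run generation in a worker thread. A stripped-down sketch of that hand-off; the work function is generic and the real task instead calls PipelineGenerator._generate(flask_app=flask_app, ...):

import contextvars
import threading

from flask import Flask, current_app, g

app = Flask(__name__)  # stand-in for the API application


def run_pipeline_in_worker(account):
    with app.app_context():
        # Set the user in g first (as the task comment above notes), then copy
        # the context so contextvar-backed state travels with the worker.
        g._login_user = account
        ctx = contextvars.copy_context()
        # Grab the concrete app object while an app context exists; the worker
        # must not rely on the main thread's context stack.
        flask_app = current_app._get_current_object()

        def _work():
            with flask_app.app_context():
                pass  # the real task performs the pipeline generation here

        # ctx.run is one way to carry the copied contextvars into the thread;
        # the committed task may thread the context through differently.
        worker = threading.Thread(target=lambda: ctx.run(_work))
        worker.start()
        worker.join()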

+ 1  - 1    api/tests/unit_tests/core/schemas/__init__.py

@@ -1 +1 @@
# Core schemas unit tests
# Core schemas unit tests

+ 192  - 301    api/tests/unit_tests/core/schemas/test_resolver.py

@@ -33,18 +33,16 @@ class TestSchemaResolver:

def test_simple_ref_resolution(self):
"""Test resolving a simple $ref to a complete schema"""
schema_with_ref = {
"$ref": "https://dify.ai/schemas/v1/qa_structure.json"
}
schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}

resolved = resolve_dify_schema_refs(schema_with_ref)

# Should be resolved to the actual qa_structure schema
assert resolved["type"] == "object"
assert resolved["title"] == "Q&A Structure Schema"
assert "qa_chunks" in resolved["properties"]
assert resolved["properties"]["qa_chunks"]["type"] == "array"
# Metadata fields should be removed
assert "$id" not in resolved
assert "$schema" not in resolved
@@ -55,29 +53,24 @@ class TestSchemaResolver:
nested_schema = {
"type": "object",
"properties": {
"file_data": {
"$ref": "https://dify.ai/schemas/v1/file.json"
},
"metadata": {
"type": "string",
"description": "Additional metadata"
}
}
"file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
"metadata": {"type": "string", "description": "Additional metadata"},
},
}

resolved = resolve_dify_schema_refs(nested_schema)

# Original structure should be preserved
assert resolved["type"] == "object"
assert "metadata" in resolved["properties"]
assert resolved["properties"]["metadata"]["type"] == "string"
# $ref should be resolved
file_schema = resolved["properties"]["file_data"]
assert file_schema["type"] == "object"
assert file_schema["title"] == "File Schema"
assert "name" in file_schema["properties"]
# Metadata fields should be removed from resolved schema
assert "$id" not in file_schema
assert "$schema" not in file_schema
@@ -87,18 +80,16 @@ class TestSchemaResolver:
"""Test resolving $refs in array items"""
array_schema = {
"type": "array",
"items": {
"$ref": "https://dify.ai/schemas/v1/general_structure.json"
},
"description": "Array of general structures"
"items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"},
"description": "Array of general structures",
}

resolved = resolve_dify_schema_refs(array_schema)
# Array structure should be preserved
assert resolved["type"] == "array"
assert resolved["description"] == "Array of general structures"
# Items $ref should be resolved
items_schema = resolved["items"]
assert items_schema["type"] == "array"
@@ -109,20 +100,16 @@ class TestSchemaResolver:
external_ref_schema = {
"type": "object",
"properties": {
"external_data": {
"$ref": "https://example.com/external-schema.json"
},
"dify_data": {
"$ref": "https://dify.ai/schemas/v1/file.json"
}
}
"external_data": {"$ref": "https://example.com/external-schema.json"},
"dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
},
}

resolved = resolve_dify_schema_refs(external_ref_schema)

# External $ref should remain unchanged
assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json"
# Dify $ref should be resolved
assert resolved["properties"]["dify_data"]["type"] == "object"
assert resolved["properties"]["dify_data"]["title"] == "File Schema"
@@ -132,22 +119,14 @@ class TestSchemaResolver:
simple_schema = {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name field"
},
"items": {
"type": "array",
"items": {
"type": "number"
}
}
"name": {"type": "string", "description": "Name field"},
"items": {"type": "array", "items": {"type": "number"}},
},
"required": ["name"]
"required": ["name"],
}

resolved = resolve_dify_schema_refs(simple_schema)

# Should be identical to input
assert resolved == simple_schema
assert resolved["type"] == "object"
@@ -159,21 +138,16 @@ class TestSchemaResolver:
"""Test that excessive recursion depth is prevented"""
# Create a moderately nested structure
deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
# Wrap it in fewer layers to make the test more reasonable
for _ in range(2):
deep_schema = {
"type": "object",
"properties": {
"nested": deep_schema
}
}
deep_schema = {"type": "object", "properties": {"nested": deep_schema}}

# Should handle normal cases fine with reasonable depth
resolved = resolve_dify_schema_refs(deep_schema, max_depth=25)
assert resolved is not None
assert resolved["type"] == "object"

# Should raise error with very low max_depth
with pytest.raises(MaxDepthExceededError) as exc_info:
resolve_dify_schema_refs(deep_schema, max_depth=5)
@@ -185,12 +159,12 @@ class TestSchemaResolver:
mock_registry = MagicMock()
mock_registry.get_schema.side_effect = lambda uri: {
"$ref": "https://dify.ai/schemas/v1/circular.json",
"type": "object"
"type": "object",
}
schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"}
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
# Should mark circular reference
assert "$circular_ref" in resolved

@@ -199,10 +173,10 @@ class TestSchemaResolver:
# Mock registry that returns None for unknown schemas
mock_registry = MagicMock()
mock_registry.get_schema.return_value = None
schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"}
resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
# Should keep the original $ref when schema not found
assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json"

@@ -217,25 +191,25 @@ class TestSchemaResolver:
def test_cache_functionality(self):
"""Test that caching works correctly"""
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
# First resolution should fetch from registry
resolved1 = resolve_dify_schema_refs(schema)
# Mock the registry to return different data
with patch.object(self.registry, "get_schema") as mock_get:
mock_get.return_value = {"type": "different"}
# Second resolution should use cache
resolved2 = resolve_dify_schema_refs(schema)
# Should be the same as first resolution (from cache)
assert resolved1 == resolved2
# Mock should not have been called
mock_get.assert_not_called()
# Clear cache and try again
SchemaResolver.clear_cache()
# Now it should fetch again
resolved3 = resolve_dify_schema_refs(schema)
assert resolved3 == resolved1
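
The test above pins down the resolver's process-wide cache contract: a second resolution of the same Dify `$ref` must not hit the registry until `SchemaResolver.clear_cache()` is called. A minimal usage sketch of that contract; the import path is assumed to be core.schemas.resolver, matching api/core/schemas/resolver.py in this commit, and the URI is taken from the tests:

from core.schemas.resolver import SchemaResolver, resolve_dify_schema_refs

schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}

first = resolve_dify_schema_refs(schema)   # fetches from the registry, populates the shared cache
second = resolve_dify_schema_refs(schema)  # served from the class-level cache
assert first == second

SchemaResolver.clear_cache()               # drop cached entries, e.g. between test cases
third = resolve_dify_schema_refs(schema)   # fetches from the registry again
assert third == first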
@@ -244,14 +218,11 @@ class TestSchemaResolver:
"""Test that the resolver is thread-safe"""
schema = {
"type": "object",
"properties": {
f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
for i in range(10)
}
"properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)},
}

results = []
def resolve_in_thread():
try:
result = resolve_dify_schema_refs(schema)
@@ -260,12 +231,12 @@ class TestSchemaResolver:
except Exception as e:
results.append(e)
return False
# Run multiple threads concurrently
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(resolve_in_thread) for _ in range(20)]
success = all(f.result() for f in futures)
assert success
# All results should be the same
first_result = results[0]
@@ -276,10 +247,7 @@ class TestSchemaResolver:
complex_schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"$ref": "https://dify.ai/schemas/v1/file.json"}
},
"files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}},
"nested": {
"type": "object",
"properties": {
@@ -290,21 +258,21 @@ class TestSchemaResolver:
"type": "object",
"properties": {
"general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}
}
}
}
}
}
}
},
},
},
},
},
},
}
resolved = resolve_dify_schema_refs(complex_schema, max_depth=20)
# Check structure is preserved
assert resolved["type"] == "object"
assert "files" in resolved["properties"]
assert "nested" in resolved["properties"]
# Check refs are resolved
assert resolved["properties"]["files"]["items"]["type"] == "object"
assert resolved["properties"]["files"]["items"]["title"] == "File Schema"
@@ -314,14 +282,14 @@ class TestSchemaResolver:

class TestUtilityFunctions:
"""Test utility functions"""
def test_is_dify_schema_ref(self):
"""Test _is_dify_schema_ref function"""
# Valid Dify refs
assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json")
assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json")
assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json")
# Invalid refs
assert not _is_dify_schema_ref("https://example.com/schema.json")
assert not _is_dify_schema_ref("https://dify.ai/other/path.json")
@@ -330,61 +298,46 @@ class TestUtilityFunctions:
assert not _is_dify_schema_ref(None)
assert not _is_dify_schema_ref(123)
assert not _is_dify_schema_ref(["list"])
def test_has_dify_refs(self):
"""Test _has_dify_refs function"""
# Schemas with Dify refs
assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"})
assert _has_dify_refs({
"type": "object",
"properties": {
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
}
})
assert _has_dify_refs([
{"type": "string"},
{"$ref": "https://dify.ai/schemas/v1/file.json"}
])
assert _has_dify_refs({
"type": "array",
"items": {
"type": "object",
"properties": {
"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
}
assert _has_dify_refs(
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}
)
assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}])
assert _has_dify_refs(
{
"type": "array",
"items": {
"type": "object",
"properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}},
},
}
})
)

# Schemas without Dify refs
assert not _has_dify_refs({"type": "string"})
assert not _has_dify_refs({
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"}
}
})
assert not _has_dify_refs([
{"type": "string"},
{"type": "number"},
{"type": "object", "properties": {"name": {"type": "string"}}}
])
assert not _has_dify_refs(
{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}
)
assert not _has_dify_refs(
[{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}]
)

# Schemas with non-Dify refs (should return False)
assert not _has_dify_refs({"$ref": "https://example.com/schema.json"})
assert not _has_dify_refs({
"type": "object",
"properties": {
"external": {"$ref": "https://example.com/external.json"}
}
})
assert not _has_dify_refs(
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}
)

# Primitive types
assert not _has_dify_refs("string")
assert not _has_dify_refs(123)
assert not _has_dify_refs(True)
assert not _has_dify_refs(None)
def test_has_dify_refs_hybrid_vs_recursive(self):
"""Test that hybrid and recursive detection give same results"""
test_schemas = [
@@ -392,29 +345,13 @@ class TestUtilityFunctions:
{"type": "string"},
{"type": "object", "properties": {"name": {"type": "string"}}},
[{"type": "string"}, {"type": "number"}],
# With Dify refs
# With Dify refs
{"$ref": "https://dify.ai/schemas/v1/file.json"},
{
"type": "object",
"properties": {
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
}
},
[
{"type": "string"},
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
],
{"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}},
[{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}],
# With non-Dify refs
{"$ref": "https://example.com/schema.json"},
{
"type": "object",
"properties": {
"external": {"$ref": "https://example.com/external.json"}
}
},
{"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}},
# Complex nested
{
"type": "object",
@@ -422,41 +359,40 @@ class TestUtilityFunctions:
"level1": {
"type": "object",
"properties": {
"level2": {
"type": "array",
"items": {"$ref": "https://dify.ai/schemas/v1/file.json"}
}
}
"level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}
},
}
}
},
},
# Edge cases
{"description": "This mentions $ref but is not a reference"},
{"$ref": "not-a-url"},
# Primitive types
"string", 123, True, None, []
"string",
123,
True,
None,
[],
]
for schema in test_schemas:
hybrid_result = _has_dify_refs_hybrid(schema)
recursive_result = _has_dify_refs_recursive(schema)
assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}"
def test_parse_dify_schema_uri(self):
"""Test parse_dify_schema_uri function"""
# Valid URIs
assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file")
assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name")
assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file")
# Invalid URIs
assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "")
assert parse_dify_schema_uri("invalid") == ("", "")
assert parse_dify_schema_uri("") == ("", "")
def test_remove_metadata_fields(self):
"""Test _remove_metadata_fields function"""
schema = {
@@ -465,68 +401,68 @@ class TestUtilityFunctions:
"version": "should be removed",
"type": "object",
"title": "should remain",
"properties": {}
"properties": {},
}
cleaned = _remove_metadata_fields(schema)
assert "$id" not in cleaned
assert "$schema" not in cleaned
assert "version" not in cleaned
assert cleaned["type"] == "object"
assert cleaned["title"] == "should remain"
assert "properties" in cleaned
# Original should be unchanged
assert "$id" in schema


class TestSchemaResolverClass:
"""Test SchemaResolver class specifically"""
def test_resolver_initialization(self):
"""Test resolver initialization"""
# Default initialization
resolver = SchemaResolver()
assert resolver.max_depth == 10
assert resolver.registry is not None
# Custom initialization
custom_registry = MagicMock()
resolver = SchemaResolver(registry=custom_registry, max_depth=5)
assert resolver.max_depth == 5
assert resolver.registry is custom_registry
def test_cache_sharing(self):
"""Test that cache is shared between resolver instances"""
SchemaResolver.clear_cache()
schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
# First resolver populates cache
resolver1 = SchemaResolver()
result1 = resolver1.resolve(schema)
# Second resolver should use the same cache
resolver2 = SchemaResolver()
with patch.object(resolver2.registry, "get_schema") as mock_get:
result2 = resolver2.resolve(schema)
# Should not call registry since it's in cache
mock_get.assert_not_called()
assert result1 == result2
def test_resolver_with_list_schema(self):
"""Test resolver with list as root schema"""
list_schema = [
{"$ref": "https://dify.ai/schemas/v1/file.json"},
{"type": "string"},
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
{"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
]
resolver = SchemaResolver()
resolved = resolver.resolve(list_schema)
assert isinstance(resolved, list)
assert len(resolved) == 3
assert resolved[0]["type"] == "object"
@@ -534,20 +470,20 @@ class TestSchemaResolverClass:
assert resolved[1] == {"type": "string"}
assert resolved[2]["type"] == "object"
assert resolved[2]["title"] == "Q&A Structure Schema"
def test_cache_performance(self):
"""Test that caching improves performance"""
SchemaResolver.clear_cache()
# Create a schema with many references to the same schema
schema = {
"type": "object",
"properties": {
f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
for i in range(50) # Reduced to avoid depth issues
}
},
}
# First run (no cache) - run multiple times to warm up
results1 = []
for _ in range(3):
@@ -556,9 +492,9 @@ class TestSchemaResolverClass:
result1 = resolve_dify_schema_refs(schema)
time_no_cache = time.perf_counter() - start
results1.append(time_no_cache)
avg_time_no_cache = sum(results1) / len(results1)
# Second run (with cache) - run multiple times
results2 = []
for _ in range(3):
@@ -566,14 +502,14 @@ class TestSchemaResolverClass:
result2 = resolve_dify_schema_refs(schema)
time_with_cache = time.perf_counter() - start
results2.append(time_with_cache)
avg_time_with_cache = sum(results2) / len(results2)
# Cache should make it faster (more lenient check)
assert result1 == result2
# Cache should provide some performance benefit
assert avg_time_with_cache <= avg_time_no_cache
def test_fast_path_performance_no_refs(self):
"""Test that schemas without $refs use fast path and avoid deep copying"""
# Create a moderately complex schema without any $refs (typical plugin output_schema)
@@ -585,16 +521,13 @@ class TestSchemaResolverClass:
"properties": {
"name": {"type": "string"},
"value": {"type": "number"},
"items": {
"type": "array",
"items": {"type": "string"}
}
}
"items": {"type": "array", "items": {"type": "string"}},
},
}
for i in range(50)
}
},
}
# Measure fast path (no refs) performance
fast_times = []
for _ in range(10):
@@ -602,21 +535,21 @@ class TestSchemaResolverClass:
result_fast = resolve_dify_schema_refs(no_refs_schema)
elapsed = time.perf_counter() - start
fast_times.append(elapsed)
avg_fast_time = sum(fast_times) / len(fast_times)
# Most importantly: result should be identical to input (no copying)
assert result_fast is no_refs_schema
# Create schema with $refs for comparison (same structure size)
with_refs_schema = {
"type": "object",
"type": "object",
"properties": {
f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
for i in range(20) # Fewer to avoid depth issues but still comparable
}
},
}
# Measure slow path (with refs) performance
SchemaResolver.clear_cache()
slow_times = []
@@ -626,63 +559,54 @@ class TestSchemaResolverClass:
result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50)
elapsed = time.perf_counter() - start
slow_times.append(elapsed)
avg_slow_time = sum(slow_times) / len(slow_times)
# The key benefit: fast path should be reasonably fast (main goal is no deep copy)
# and definitely avoid the expensive BFS resolution
# Even if detection has some overhead, it should still be faster for typical cases
print(f"Fast path (no refs): {avg_fast_time:.6f}s")
print(f"Slow path (with refs): {avg_slow_time:.6f}s")
# More lenient check: fast path should be at least somewhat competitive
# The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster
assert avg_fast_time < avg_slow_time * 2 # Should not be more than 2x slower
def test_batch_processing_performance(self):
"""Test performance improvement for batch processing of schemas without refs"""
# Simulate the plugin tool scenario: many schemas, most without refs
schemas_without_refs = [
{
"type": "object",
"properties": {
f"field_{j}": {"type": "string" if j % 2 else "number"}
for j in range(10)
}
"properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)},
}
for i in range(100)
]
# Test batch processing performance
start = time.perf_counter()
results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs]
batch_time = time.perf_counter() - start
# Verify all results are identical to inputs (fast path used)
for original, result in zip(schemas_without_refs, results):
assert result is original
# Should be very fast - each schema should take < 0.001 seconds on average
avg_time_per_schema = batch_time / len(schemas_without_refs)
assert avg_time_per_schema < 0.001
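
These performance tests encode the fast-path guarantee: when a schema contains no Dify `$ref`s, `resolve_dify_schema_refs` returns the input object itself, with no deep copy and no BFS resolution. A short usage sketch of that identity check; the import path is assumed as in the note above:

from core.schemas.resolver import resolve_dify_schema_refs

plain_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "count": {"type": "number"}},
}

resolved = resolve_dify_schema_refs(plain_schema)
# Fast path: no Dify refs, so the very same dict comes back untouched.
assert resolved is plain_schema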
def test_has_dify_refs_performance(self):
"""Test that _has_dify_refs is fast for large schemas without refs"""
# Create a very large schema without refs
large_schema = {
"type": "object",
"properties": {}
}
large_schema = {"type": "object", "properties": {}}

# Add many nested properties
current = large_schema
for i in range(100):
current["properties"][f"level_{i}"] = {
"type": "object",
"properties": {}
}
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
current = current["properties"][f"level_{i}"]

# _has_dify_refs should be fast even for large schemas
times = []
for _ in range(50):
@@ -690,13 +614,13 @@ class TestSchemaResolverClass:
has_refs = _has_dify_refs(large_schema)
elapsed = time.perf_counter() - start
times.append(elapsed)
avg_time = sum(times) / len(times)
# Should be False and fast
assert not has_refs
assert avg_time < 0.01 # Should complete in less than 10ms
def test_hybrid_vs_recursive_performance(self):
"""Test performance comparison between hybrid and recursive detection"""
# Create test schemas of different types and sizes
@@ -704,16 +628,9 @@ class TestSchemaResolverClass:
# Case 1: Small schema without refs (most common case)
{
"name": "small_no_refs",
"schema": {
"type": "object",
"properties": {
"name": {"type": "string"},
"value": {"type": "number"}
}
},
"expected": False
"schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}},
"expected": False,
},
# Case 2: Medium schema without refs
{
"name": "medium_no_refs",
@@ -725,28 +642,16 @@ class TestSchemaResolverClass:
"properties": {
"name": {"type": "string"},
"value": {"type": "number"},
"items": {
"type": "array",
"items": {"type": "string"}
}
}
"items": {"type": "array", "items": {"type": "string"}},
},
}
for i in range(20)
}
},
},
"expected": False
"expected": False,
},
# Case 3: Large schema without refs
{
"name": "large_no_refs",
"schema": {
"type": "object",
"properties": {}
},
"expected": False
},
{"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False},
# Case 4: Schema with Dify refs
{
"name": "with_dify_refs",
@@ -754,45 +659,38 @@ class TestSchemaResolverClass:
"type": "object",
"properties": {
"file": {"$ref": "https://dify.ai/schemas/v1/file.json"},
"data": {"type": "string"}
}
"data": {"type": "string"},
},
},
"expected": True
"expected": True,
},
# Case 5: Schema with non-Dify refs
{
"name": "with_external_refs",
"schema": {
"type": "object",
"properties": {
"external": {"$ref": "https://example.com/schema.json"},
"data": {"type": "string"}
}
"type": "object",
"properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}},
},
"expected": False
}
"expected": False,
},
]
# Add deep nesting to large schema
current = test_cases[2]["schema"]
for i in range(50):
current["properties"][f"level_{i}"] = {
"type": "object",
"properties": {}
}
current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
current = current["properties"][f"level_{i}"]

# Performance comparison
for test_case in test_cases:
schema = test_case["schema"]
expected = test_case["expected"]
name = test_case["name"]
# Test correctness first
assert _has_dify_refs_hybrid(schema) == expected
assert _has_dify_refs_recursive(schema) == expected
# Measure hybrid performance
hybrid_times = []
for _ in range(10):
@@ -800,7 +698,7 @@ class TestSchemaResolverClass:
result_hybrid = _has_dify_refs_hybrid(schema)
elapsed = time.perf_counter() - start
hybrid_times.append(elapsed)
# Measure recursive performance
recursive_times = []
for _ in range(10):
@@ -808,69 +706,62 @@ class TestSchemaResolverClass:
result_recursive = _has_dify_refs_recursive(schema)
elapsed = time.perf_counter() - start
recursive_times.append(elapsed)
avg_hybrid = sum(hybrid_times) / len(hybrid_times)
avg_recursive = sum(recursive_times) / len(recursive_times)
print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s")
# Results should be identical
assert result_hybrid == result_recursive == expected
# For schemas without refs, hybrid should be competitive or better
if not expected: # No refs case
# Hybrid might be slightly slower due to JSON serialization overhead,
# but should not be dramatically worse
assert avg_hybrid < avg_recursive * 5 # At most 5x slower
def test_string_matching_edge_cases(self):
"""Test edge cases for string-based detection"""
# Case 1: False positive potential - $ref in description
schema_false_positive = {
"type": "object",
"properties": {
"description": {
"type": "string",
"description": "This field explains how $ref works in JSON Schema"
}
}
"description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"}
},
}

# Both methods should return False
assert not _has_dify_refs_hybrid(schema_false_positive)
assert not _has_dify_refs_recursive(schema_false_positive)
# Case 2: Complex URL patterns
complex_schema = {
"type": "object",
"properties": {
"config": {
"type": "object",
"type": "object",
"properties": {
"dify_url": {
"type": "string",
"default": "https://dify.ai/schemas/info"
},
"actual_ref": {
"$ref": "https://dify.ai/schemas/v1/file.json"
}
}
"dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"},
"actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"},
},
}
}
},
}

# Both methods should return True (due to actual_ref)
assert _has_dify_refs_hybrid(complex_schema)
assert _has_dify_refs_recursive(complex_schema)
# Case 3: Non-JSON serializable objects (should fall back to recursive)
import datetime

non_serializable = {
"type": "object",
"timestamp": datetime.datetime.now(),
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}
"data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
}
# Hybrid should fall back to recursive and still work
assert _has_dify_refs_hybrid(non_serializable)
assert _has_dify_refs_recursive(non_serializable)
assert _has_dify_refs_recursive(non_serializable)
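
The edge-case tests constrain `_has_dify_refs_hybrid`: it must agree with the recursive walk, must not false-positive when `$ref` merely appears in prose or when only non-Dify refs are present, and must fall back to the recursive walk for non-JSON-serializable input. One plausible shape for such a check, offered purely as an illustration of those constraints rather than the committed implementation:

import json


def _walk(node) -> bool:
    # Exact check: a "$ref" key whose value points at a Dify schema URL.
    if isinstance(node, dict):
        ref = node.get("$ref")
        if isinstance(ref, str) and ref.startswith("https://dify.ai/schemas/"):
            return True
        return any(_walk(v) for v in node.values())
    if isinstance(node, list):
        return any(_walk(item) for item in node)
    return False


def has_dify_refs_hybrid_sketch(schema) -> bool:
    """Cheap string screen first, exact recursive walk only when needed."""
    try:
        blob = json.dumps(schema)
    except (TypeError, ValueError):
        # datetime and other non-serializable values: do the exact walk.
        return _walk(schema)
    if "https://dify.ai/schemas/" not in blob:
        # No Dify schema URL anywhere, so no Dify $ref is possible.
        return False
    # The URL appears somewhere; confirm it actually sits under a "$ref" key.
    return _walk(schema)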
