
r2

tags/2.0.0-beta.1
jyong, 5 months ago
parent commit b82b26bba5
36 changed files with 1983 additions and 331 deletions
 1. + 115  - 21   api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
 2. + 1    - 1    api/core/app/app_config/entities.py
 3. + 0    - 0    api/core/app/apps/pipeline/__init__.py
 4. + 95   - 0    api/core/app/apps/pipeline/generate_response_converter.py
 5. + 63   - 0    api/core/app/apps/pipeline/pipeline_config_manager.py
 6. + 496  - 0    api/core/app/apps/pipeline/pipeline_generator.py
 7. + 44   - 0    api/core/app/apps/pipeline/pipeline_queue_manager.py
 8. + 154  - 0    api/core/app/apps/pipeline/pipeline_runner.py
 9. + 35   - 0    api/core/app/entities/app_invoke_entities.py
10. + 11   - 54   api/core/datasource/__base/datasource_plugin.py
11. + 14   - 31   api/core/datasource/__base/datasource_provider.py
12. + 3    - 3    api/core/datasource/entities/api_entities.py
13. + 108  - 6    api/core/datasource/entities/datasource_entities.py
14. + 37   - 0    api/core/datasource/local_file/local_file_plugin.py
15. + 58   - 0    api/core/datasource/local_file/local_file_provider.py
16. + 80   - 0    api/core/datasource/online_document/online_document_plugin.py
17. + 50   - 0    api/core/datasource/online_document/online_document_provider.py
18. + 63   - 0    api/core/datasource/website_crawl/website_crawl_plugin.py
19. + 50   - 0    api/core/datasource/website_crawl/website_crawl_provider.py
20. + 1    - 0    api/core/plugin/entities/plugin_daemon.py
21. + 126  - 27   api/core/plugin/impl/datasource.py
22. + 4    - 0    api/core/workflow/enums.py
23. + 40   - 7    api/core/workflow/nodes/datasource/datasource_node.py
24. + 0    - 5    api/core/workflow/nodes/knowledge_index/entities.py
25. + 23   - 80   api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
26. + 0    - 66   api/core/workflow/nodes/knowledge_index/template_prompts.py
27. + 0    - 2    api/core/workflow/nodes/knowledge_retrieval/entities.py
28. + 1    - 1    api/factories/variable_factory.py
29. + 113  - 0    api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py
30. + 2    - 0    api/models/dataset.py
31. + 1    - 0    api/models/model.py
32. + 3    - 3    api/models/workflow.py
33. + 1    - 1    api/services/dataset_service.py
34. + 109  - 0    api/services/rag_pipeline/pipeline_generate_service.py
35. + 15   - 15   api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
36. + 67   - 8    api/services/rag_pipeline/rag_pipeline.py

+ 115  - 21   api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py

@@ -39,9 +39,9 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
@@ -170,7 +170,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_iteration(
response = PipelineGenerateService.generate_single_iteration(
pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
)

@@ -207,7 +207,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_loop(
response = PipelineGenerateService.generate_single_loop(
pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True
)

@@ -241,11 +241,12 @@ class DraftRagPipelineRunApi(Resource):

parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json")
args = parser.parse_args()

try:
response = AppGenerateService.generate(
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
@@ -258,7 +259,73 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)


class PublishedRagPipelineRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json")
args = parser.parse_args()

try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED,
streaming=True,
)

return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)


class RagPipelineDatasourceNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not isinstance(current_user, Account):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()

inputs = args.get("inputs")

rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
)

return result


class RagPipelinePublishedNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -283,7 +350,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
raise ValueError("missing inputs")

rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node(
workflow_node_execution = rag_pipeline_service.run_published_workflow_node(
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
)

@@ -354,7 +421,8 @@ class PublishedRagPipelineApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()

if not pipeline.is_published:
return None
# fetch published workflow by pipeline
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
@@ -397,10 +465,8 @@ class PublishedRagPipelineApi(Resource):
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
db.session.commit()

workflow_created_at = TimestampField().format(workflow.created_at)

session.commit()
@@ -617,7 +683,7 @@ class RagPipelineByIdApi(Resource):
return None, 204


class RagPipelineSecondStepApi(Resource):
class PublishedRagPipelineSecondStepApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -632,9 +698,28 @@ class RagPipelineSecondStepApi(Resource):
node_id = request.args.get("node_id", required=True, type=str)

rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(
pipeline=pipeline, node_id=node_id
)
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
return {
"variables": variables,
}


class DraftRagPipelineSecondStepApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
node_id = request.args.get("node_id", required=True, type=str)

rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
return {
"variables": variables,
}
@@ -732,15 +817,21 @@ api.add_resource(
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
# api.add_resource(
# RagPipelinePublishedNodeRunApi,
# "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
# )
api.add_resource(
RagPipelineDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/datasource/nodes/<string:node_id>/run",
)

api.add_resource(
RagPipelineDraftRunIterationNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)

api.add_resource(
RagPipelinePublishedNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
)

api.add_resource(
RagPipelineDraftRunLoopNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
@@ -762,7 +853,6 @@ api.add_resource(
DefaultRagPipelineBlockConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)

api.add_resource(
RagPipelineByIdApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
@@ -784,6 +874,10 @@ api.add_resource(
"/rag/pipelines/datasource-plugins",
)
api.add_resource(
RagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/processing/paramters",
PublishedRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters",
)
api.add_resource(
DraftRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters",
)
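
A minimal client-side sketch of exercising the new PublishedRagPipelineRunApi endpoint. The base URL, route path, and auth header below are assumptions (the route registration for this resource is not part of the hunk shown); only the "inputs" / "datasource_type" / "datasource_info" body fields come from the request parser above, and the list items are placeholders.

import requests

# Hypothetical console URL, token, and published-run path.
resp = requests.post(
    "http://localhost:5001/console/api/rag/pipelines/<pipeline_id>/workflows/published/run",
    headers={"Authorization": "Bearer <console-access-token>"},
    json={
        "inputs": {},
        "datasource_type": "website_crawl",
        "datasource_info": [{"title": "Example Domain", "source_url": "https://example.com"}],
    },
    stream=True,  # the endpoint responds with a stream (streaming=True)
)
for line in resp.iter_lines():
    print(line.decode())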

+ 1  - 1   api/core/app/app_config/entities.py

@@ -283,7 +283,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures
additional_features: Optional[AppAdditionalFeatures] = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


+ 0  - 0   api/core/app/apps/pipeline/__init__.py


+ 95  - 0   api/core/app/apps/pipeline/generate_response_converter.py

@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse

@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue

response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}

if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue

response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}

if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
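
A brief usage sketch of the converter above: the simple-response generator yields the literal string "ping" for keep-alives and plain dicts otherwise, so serializing it into server-sent events might look like this (the SSE framing itself is illustrative, not part of the diff).

import json

from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter


def to_sse(stream_response):
    # stream_response is assumed to be the Generator[AppStreamResponse, None, None]
    # produced by the generate task pipeline.
    for item in WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream_response):
        if isinstance(item, str):
            # keep-alive ("ping") events carry no payload
            yield f"event: {item}\n\n"
        else:
            yield "data: " + json.dumps(item) + "\n\n"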

+ 63  - 0   api/core/app/apps/pipeline/pipeline_config_manager.py

@@ -0,0 +1,63 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow


class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""

pass


class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
)

return pipeline_config

@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config

:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []

# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)

# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))

# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}

return filtered_config
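
A minimal sketch of calling config_validate. The nested keys are assumptions inferred from the validators invoked above (file upload, text-to-speech, moderation); anything outside the related keys is dropped from the result.

raw_config = {
    "file_upload": {"enabled": False},               # assumed shape
    "text_to_speech": {"enabled": False},            # assumed shape
    "sensitive_word_avoidance": {"enabled": False},  # assumed shape
    "unrelated_key": "filtered out",
}

filtered = PipelineConfigManager.config_validate(tenant_id="<tenant-id>", config=raw_config)
print(filtered)  # only the related config keys, with defaults applied, remain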

+ 496  - 0   api/core/app/apps/pipeline/pipeline_generator.py

@@ -0,0 +1,496 @@
import contextvars
import datetime
import json
import logging
import random
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from services.dataset_service import DocumentService

logger = logging.getLogger(__name__)


class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...

@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
)

inputs: Mapping[str, Any] = args["inputs"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
position = DocumentService.get_documents_position(pipeline.dataset_id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=pipeline.dataset_id,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.doc_form,
)
db.session.add(document)
db.session.commit()
document_id = document.id
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
pipline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=pipeline.dataset_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_run_id=workflow_run_id,
)

contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

return self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)

def _generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.

:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=pipeline.mode,
)

# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)

worker_thread.start()

# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)

return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def single_iteration_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.

:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")

if args.get("inputs") is None:
raise ValueError("inputs is required")

# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)

return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)

def single_loop_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.

:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")

if args.get("inputs") is None:
raise ValueError("inputs is required")

# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)

return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)

def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
var.set(val)
with flask_app.app_context():
try:
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
)

runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()

def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
)

try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
)
raise e

def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Account,
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info["name"]
elif datasource_type == "online_document":
name = datasource_info["page_title"]
elif datasource_type == "website_crawl":
name = datasource_info["title"]
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")

document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
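
A minimal sketch of invoking the generator in streaming mode, assuming the imports used at the top of this file. `pipeline`, `workflow`, and `current_user` are assumed to be loaded by the caller (as PipelineGenerateService would do), and the datasource payload mirrors the "website_crawl" branch of _build_document above.

generator = PipelineGenerator()
response = generator.generate(
    pipeline=pipeline,
    workflow=workflow,
    user=current_user,
    args={
        "inputs": {},
        "datasource_type": "website_crawl",
        "datasource_info_list": [{"title": "Example Domain", "source_url": "https://example.com"}],
    },
    invoke_from=InvokeFrom.PUBLISHED,
    streaming=True,
    call_depth=0,
    workflow_thread_pool_id=None,
)
for chunk in response:  # with streaming=True the result is a generator of mappings/strings
    print(chunk)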

+ 44  - 0   api/core/app/apps/pipeline/pipeline_queue_manager.py

@@ -0,0 +1,44 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)


class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)

self._app_mode = app_mode

def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

self._q.put(message)

if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()

if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()
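
The `_publish` override above treats several queue events as terminal and stops listening when one arrives; the isinstance call uses a PEP 604 union of event classes, which on Python 3.10+ is equivalent to passing a tuple of types, as this standalone sketch shows.

class Stop: ...
class Error: ...

event = Stop()
assert isinstance(event, Stop | Error)     # union form used above (Python 3.10+)
assert isinstance(event, (Stop, Error))    # equivalent tuple form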

+ 154  - 0   api/core/app/apps/pipeline/pipeline_runner.py

@@ -0,0 +1,154 @@
import logging
from typing import Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)


class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""

def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id

def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)

user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id

pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")

workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")

db.session.close()

workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files

# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
SystemVariableKey.BATCH: self.application_generate_entity.batch,
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
}

variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)

# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)

generator = workflow_entry.run(callbacks=workflow_callbacks)

for event in generator:
self._handle_event(workflow_entry, event)

def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
)
.first()
)

# return workflow
return workflow

+ 35  - 0   api/core/app/entities/app_invoke_entities.py

@@ -21,6 +21,7 @@ class InvokeFrom(Enum):
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED = "published"

@classmethod
def value_of(cls, value: str):
@@ -226,3 +227,37 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
inputs: dict

single_loop_run: Optional[SingleLoopRunEntity] = None


class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""

# app config
pipline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: str

class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""

node_id: str
inputs: dict

single_iteration_run: Optional[SingleIterationRunEntity] = None

class SingleLoopRunEntity(BaseModel):
"""
Single Loop Run Entity.
"""

node_id: str
inputs: dict

single_loop_run: Optional[SingleLoopRunEntity] = None

+ 11  - 54   api/core/datasource/__base/datasource_plugin.py

@@ -1,18 +1,13 @@
from collections.abc import Mapping
from typing import Any
from abc import ABC, abstractmethod

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


class DatasourcePlugin:
tenant_id: str
icon: str
plugin_unique_identifier: str
class DatasourcePlugin(ABC):
entity: DatasourceEntity
runtime: DatasourceRuntime

@@ -20,57 +15,19 @@ class DatasourcePlugin:
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
self.entity = entity
self.runtime = runtime
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier

def _invoke_first_step(
self,
user_id: str,
datasource_parameters: dict[str, Any],
) -> Mapping[str, Any]:
manager = PluginDatasourceManager()

datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

return manager.invoke_first_step(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
)

def _invoke_second_step(
self,
user_id: str,
datasource_parameters: dict[str, Any],
) -> Mapping[str, Any]:
manager = PluginDatasourceManager()

datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

return manager.invoke_second_step(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
)
@abstractmethod
def datasource_provider_type(self) -> DatasourceProviderType:
"""
returns the type of the datasource provider
"""
return DatasourceProviderType.LOCAL_FILE

def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 14  - 31   api/core/datasource/__base/datasource_provider.py

@@ -1,26 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError


class DatasourcePluginProviderController:
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str

def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

@property
def need_credentials(self) -> bool:
@@ -44,29 +37,19 @@ class DatasourcePluginProviderController:
):
raise ToolProviderCredentialValidationError("Invalid credentials")

def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE

@abstractmethod
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)

if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")

return DatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
pass

def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
"""

+ 3  - 3   api/core/datasource/entities/api_entities.py

@@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel):
description: I18nObject
icon: str | dict
label: I18nObject # label
type: ToolProviderType
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)


+ 108  - 6   api/core/datasource/entities/datasource_entities.py

@@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum):

ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE = "website"
WEBSITE_CRAWL = "website_crawl"

@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
@@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter):


class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
provider: str = Field(..., description="The provider of the tool")
author: str = Field(..., description="The author of the datasource")
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None


@@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):


class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
datasources: list[DatasourceEntity] = Field(default_factory=list)


class DatasourceInvokeMeta(BaseModel):
@@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum):
"""

RAG_PIPELINE = "rag_pipeline"


class GetOnlineDocumentPagesRequest(BaseModel):
"""
Get online document pages request
"""

tenant_id: str = Field(..., description="The tenant id")


class OnlineDocumentPageIcon(BaseModel):
"""
Online document page icon
"""

type: str = Field(..., description="The type of the icon")
url: str = Field(..., description="The url of the icon")


class OnlineDocumentPage(BaseModel):
"""
Online document page
"""

page_id: str = Field(..., description="The page id")
page_title: str = Field(..., description="The page title")
page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")


class OnlineDocumentInfo(BaseModel):
"""
Online document info
"""

workspace_id: str = Field(..., description="The workspace id")
workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")


class GetOnlineDocumentPagesResponse(BaseModel):
"""
Get online document pages response
"""

result: list[OnlineDocumentInfo]


class GetOnlineDocumentPageContentRequest(BaseModel):
"""
Get online document page content request
"""

online_document_info_list: list[OnlineDocumentInfo]


class OnlineDocumentPageContent(BaseModel):
"""
Online document page content
"""

page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")


class GetOnlineDocumentPageContentResponse(BaseModel):
"""
Get online document page content response
"""

result: list[OnlineDocumentPageContent]


class GetWebsiteCrawlRequest(BaseModel):
"""
Get website crawl request
"""

url: str = Field(..., description="The url of the website")
crawl_parameters: dict = Field(..., description="The crawl parameters")


class WebSiteInfo(BaseModel):
"""
Website info
"""

source_url: str = Field(..., description="The url of the website")
markdown: str = Field(..., description="The markdown of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")


class GetWebsiteCrawlResponse(BaseModel):
"""
Get website crawl response
"""

result: list[WebSiteInfo]
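
A minimal construction sketch for the new website-crawl models defined above; field values are placeholders and serialization assumes pydantic v2.

site = WebSiteInfo(
    source_url="https://example.com",
    markdown="# Example Domain",
    title="Example Domain",
    description="Illustrative crawl result",
)
response = GetWebsiteCrawlResponse(result=[site])
print(response.model_dump_json())  # pydantic v2 serialization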

+ 37  - 0   api/core/datasource/local_file/local_file_plugin.py

@@ -0,0 +1,37 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)


class LocalFileDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str

def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier

def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.LOCAL_FILE

def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 58  - 0   api/core/datasource/local_file/local_file_provider.py

@@ -0,0 +1,58 @@
from typing import Any

from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin


class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str

def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE

def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass

def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)

if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")

return LocalFileDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
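
A minimal wiring sketch for the local-file provider controller. The entity, plugin identifiers, and datasource name are placeholders; only the constructor signature and the get_datasource lookup come from the code above.

provider = LocalFileDatasourcePluginProviderController(
    entity=provider_entity,  # a DatasourceProviderEntityWithPlugin, e.g. fetched from the plugin daemon
    plugin_id="<plugin-id>",
    plugin_unique_identifier="<plugin-unique-identifier>",
    tenant_id="<tenant-id>",
)
datasource = provider.get_datasource("<datasource-name>")  # raises ValueError if the name is unknown
assert datasource.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE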

+ 80  - 0   api/core/datasource/online_document/online_document_plugin.py

@@ -0,0 +1,80 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesRequest,
GetOnlineDocumentPagesResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime

def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier

def _get_online_document_pages(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPagesRequest,
provider_type: str,
) -> GetOnlineDocumentPagesResponse:
manager = PluginDatasourceManager()

return manager.get_online_document_pages(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)

def _get_online_document_page_content(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> GetOnlineDocumentPageContentResponse:
manager = PluginDatasourceManager()

return manager.get_online_document_page_content(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)

def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.ONLINE_DOCUMENT

def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 50  - 0   api/core/datasource/online_document/online_document_provider.py

@@ -0,0 +1,50 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType


class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str

def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DOCUMENT

def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)

if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")

return DatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 63  - 0   api/core/datasource/website_crawl/website_crawl_plugin.py

@@ -0,0 +1,63 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime

def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier

def _get_website_crawl(
self,
user_id: str,
datasource_parameters: GetWebsiteCrawlRequest,
provider_type: str,
) -> GetWebsiteCrawlResponse:
manager = PluginDatasourceManager()

datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

return manager.invoke_first_step(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)

def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.WEBSITE_CRAWL

def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 50  - 0   api/core/datasource/website_crawl/website_crawl_provider.py

@@ -0,0 +1,50 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin


class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str

def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.WEBSITE_CRAWL

def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)

if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")

        return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

+ 1
- 0
api/core/plugin/entities/plugin_daemon.py Ver arquivo

@@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
author: str
declaration: DatasourceProviderEntityWithPlugin



+ 126
- 27
api/core/plugin/impl/datasource.py Ver arquivo

@@ -1,6 +1,14 @@
from collections.abc import Mapping
from typing import Any

from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import (
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesRequest,
GetOnlineDocumentPagesResponse,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse,
)
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
@@ -10,7 +18,7 @@ from core.plugin.impl.base import BasePluginClient


class PluginDatasourceManager(BasePluginClient):
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]:
"""
Fetch datasource providers for the given tenant.
"""
@@ -19,27 +27,27 @@ class PluginDatasourceManager(BasePluginClient):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name

return json_response

response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
# response = self._request_with_plugin_daemon_response(
# "GET",
# f"plugin/{tenant_id}/management/datasources",
# list[PluginDatasourceProviderEntity],
# params={"page": 1, "page_size": 256},
# transformer=transformer,
# )

for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# for provider in response:
# provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

# override the provider name for each tool to plugin_id/provider_name
for datasource in provider.declaration.datasources:
datasource.identity.provider = provider.declaration.identity.name
# # override the provider name for each tool to plugin_id/provider_name
# for datasource in provider.declaration.datasources:
# datasource.identity.provider = provider.declaration.identity.name

return response
return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())]

def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
"""
@@ -71,15 +79,16 @@ class PluginDatasourceManager(BasePluginClient):

return response

def invoke_first_step(
def get_website_crawl(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: dict[str, Any],
) -> Mapping[str, Any]:
datasource_parameters: GetWebsiteCrawlRequest,
provider_type: str,
) -> GetWebsiteCrawlResponse:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
@@ -88,8 +97,8 @@ class PluginDatasourceManager(BasePluginClient):

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/first_step",
dict,
f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl",
GetWebsiteCrawlResponse,
data={
"user_id": user_id,
"data": {
@@ -109,15 +118,16 @@ class PluginDatasourceManager(BasePluginClient):

raise Exception("No response from plugin daemon")

def invoke_second_step(
def get_online_document_pages(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: dict[str, Any],
) -> Mapping[str, Any]:
datasource_parameters: GetOnlineDocumentPagesRequest,
provider_type: str,
) -> GetOnlineDocumentPagesResponse:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
@@ -126,8 +136,47 @@ class PluginDatasourceManager(BasePluginClient):

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/second_step",
dict,
f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages",
GetOnlineDocumentPagesResponse,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp

raise Exception("No response from plugin daemon")

def get_online_document_page_content(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> GetOnlineDocumentPageContentResponse:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""

datasource_provider_id = GenericProviderID(datasource_provider)

response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content",
GetOnlineDocumentPageContentResponse,
data={
"user_id": user_id,
"data": {
@@ -176,3 +225,53 @@ class PluginDatasourceManager(BasePluginClient):
return resp.result

return False

def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return {
"id": "langgenius/file/file",
"author": "langgenius",
"name": "langgenius/file/file",
"plugin_id": "langgenius/file",
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"type": "datasource",
"team_credentials": {},
"is_team_authorization": False,
"allow_delete": True,
"datasources": [{
"author": "langgenius",
"name": "upload_file",
"label": {
"en_US": "File",
"zh_Hans": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"description": {
"en_US": "File",
"zh_Hans": "File",
"pt_BR": "File",
"ja_JP": "File."
},
"parameters": [],
"labels": [
"search"
],
"output_schema": None
}],
"labels": [
"search"
]
}
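
A minimal invocation sketch for the renamed manager methods above (every value is a placeholder; the provider id format mirrors the plugin_id/provider_name convention used by the local file stub, and the provider_type string is assumed to be the enum value of DatasourceProviderType.WEBSITE_CRAWL):

    manager = PluginDatasourceManager()
    crawl_response = manager.get_website_crawl(
        tenant_id=tenant_id,
        user_id=user_id,
        datasource_provider="langgenius/firecrawl/firecrawl",
        datasource_name="crawl",
        credentials={"api_key": "..."},
        datasource_parameters=crawl_request,  # a GetWebsiteCrawlRequest built by the caller
        provider_type="website_crawl",
    )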

+ 4
- 0
api/core/workflow/enums.py Ver arquivo

@@ -14,3 +14,7 @@ class SystemVariableKey(StrEnum):
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_RUN_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
BATCH = "batch"
DATASET_ID = "dataset_id"
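
A minimal sketch of how the new pipeline keys are read back out of a run's variable pool (this mirrors the lookups added to KnowledgeIndexNode later in this commit):

    dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
    batch = variable_pool.get(["sys", SystemVariableKey.BATCH])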

+ 40
- 7
api/core/workflow/nodes/datasource/datasource_node.py Ver arquivo

@@ -3,7 +3,11 @@ from typing import Any, cast

from core.datasource.entities.datasource_entities import (
DatasourceParameter,
DatasourceProviderType,
GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
@@ -77,15 +81,44 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
for_log=True,
)

# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

try:
# TODO: handle result
result = datasource_runtime._invoke_second_step(
user_id=self.user_id,
datasource_parameters=parameters,
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
result = datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
            elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
else:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(

+ 0
- 5
api/core/workflow/nodes/knowledge_index/entities.py Ver arquivo

@@ -155,9 +155,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
"""

type: str = "knowledge-index"
dataset_id: str
document_id: str
index_chunk_variable_selector: list[str]
chunk_structure: Literal["general", "parent-child"]
index_method: IndexMethod
retrieval_setting: RetrievalSetting

+ 23
- 80
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py Ver arquivo

@@ -1,25 +1,19 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast

from flask_login import current_user

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from services.dataset_service import DatasetCollectionBindingService
from services.feature_service import FeatureService

from .entities import KnowledgeIndexNodeData
from .exc import (
@@ -43,8 +37,9 @@ class KnowledgeIndexNode(LLMNode):

def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
variable = variable_pool.get(node_data.index_chunk_variable_selector)
if not isinstance(variable, ObjectSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -57,34 +52,9 @@ class KnowledgeIndexNode(LLMNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)

# retrieve knowledge
try:
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@@ -107,54 +77,26 @@ class KnowledgeIndexNode(LLMNode):
error_type=type(e).__name__,
)

def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any:
dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
def _invoke_knowledge_index(
self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
) -> Any:
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
dataset = Dataset.query.filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

document = Document.query.filter_by(id=node_data.document_id).first()
document = Document.query.filter_by(id=document_id).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.")

retrieval_setting = node_data.retrieval_setting
index_method = node_data.index_method
if not dataset.indexing_technique:
if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")

dataset.indexing_technique = index_method.indexing_technique
if index_method.indexing_technique == "high_quality":
model_manager = ModelManager()
if (
index_method.embedding_setting.embedding_model
and index_method.embedding_setting.embedding_model_provider
):
dataset_embedding_model = index_method.embedding_setting.embedding_model
dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")

dataset.retrieval_model = (
retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
) # type: ignore
index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor()
index_processor.index(dataset, document, chunks)

@@ -166,6 +108,7 @@ class KnowledgeIndexNode(LLMNode):
return {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"batch": batch,
"document_id": document.id,
"document_name": document.name,
"created_at": document.created_at,

+ 0
- 66
api/core/workflow/nodes/knowledge_index/template_prompts.py Ver arquivo

@@ -1,66 +0,0 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

+ 0
- 2
api/core/workflow/nodes/knowledge_retrieval/entities.py Ver arquivo

@@ -57,8 +57,6 @@ class MultipleRetrievalConfig(BaseModel):


class ModelConfig(BaseModel):


provider: str
name: str
mode: str

+ 1
- 1
api/factories/variable_factory.py Ver arquivo

@@ -39,7 +39,6 @@ from core.variables.variables import (
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
PIPELINE_VARIABLE_NODE_ID,
)


@@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)


def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()

+ 113
- 0
api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py Ver arquivo

@@ -0,0 +1,113 @@
"""add_pipeline_info_2

Revision ID: abb18a379e62
Revises: b35c3db83d09
Create Date: 2025-05-16 16:59:16.423127

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'abb18a379e62'
down_revision = 'b35c3db83d09'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('component_failure_stats')
op.drop_table('reliability_data')
op.drop_table('maintenance')
op.drop_table('operational_data')
op.drop_table('component_failure')
op.drop_table('tool_providers')
op.drop_table('safety_data')
op.drop_table('incident_data')
with op.batch_alter_table('pipelines', schema=None) as batch_op:
batch_op.drop_column('mode')

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipelines', schema=None) as batch_op:
batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))

op.create_table('incident_data',
sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False),
sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False),
sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey')
)
op.create_table('safety_data',
sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False),
sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey')
)
op.create_table('tool_providers',
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
op.create_table('component_failure',
sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False),
sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'),
sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry')
)
op.create_table('operational_data',
sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False),
sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey')
)
op.create_table('maintenance',
sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False),
sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey')
)
op.create_table('reliability_data',
sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey')
)
op.create_table('component_failure_stats',
sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey')
)
# ### end Alembic commands ###
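
A minimal sketch for applying this revision with the Alembic API directly (the config path is an assumption; Dify's API normally applies migrations through Flask-Migrate, e.g. flask db upgrade):

    from alembic import command
    from alembic.config import Config

    config = Config("api/migrations/alembic.ini")  # assumed location of the Alembic config
    command.upgrade(config, "abb18a379e62")        # apply up to this revision
    # command.downgrade(config, "b35c3db83d09")    # roll back to the previous revision if needed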

+ 2
- 0
api/models/dataset.py Ver arquivo

@@ -1170,6 +1170,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
def pipeline(self):
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()


class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (
@@ -1205,6 +1206,7 @@ class Pipeline(Base): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()

+ 1
- 0
api/models/model.py Ver arquivo

@@ -52,6 +52,7 @@ class AppMode(StrEnum):
ADVANCED_CHAT = "advanced-chat"
AGENT_CHAT = "agent-chat"
CHANNEL = "channel"
RAG_PIPELINE = "rag-pipeline"

@classmethod
def value_of(cls, value: str) -> "AppMode":

+ 3
- 3
api/models/workflow.py Ver arquivo

@@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4

from core.variables import utils as variable_utils
@@ -43,7 +43,7 @@ class WorkflowType(Enum):

WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag_pipeline"
RAG_PIPELINE = "rag-pipeline"

@classmethod
def value_of(cls, value: str) -> "WorkflowType":
@@ -370,7 +370,7 @@ class Workflow(Base):
return results

@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: List[dict]) -> None:
def rag_pipeline_variables(self, values: list[dict]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
ensure_ascii=False,

+ 1
- 1
api/services/dataset_service.py Ver arquivo

@@ -1550,7 +1550,7 @@ class DocumentService:
@staticmethod
def build_document(
dataset: Dataset,
process_rule_id: str,
process_rule_id: str | None,
data_source_type: str,
document_form: str,
document_language: str,

+ 109
- 0
api/services/rag_pipeline/pipeline_generate_service.py Ver arquivo

@@ -0,0 +1,109 @@
from collections.abc import Mapping
from typing import Any, Union

from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService


class PipelineGenerateService:
@classmethod
def generate(
cls,
pipeline: Pipeline,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Pipeline Content Generate
:param pipeline: pipeline
:param user: user
:param args: args
:param invoke_from: invoke from
:param streaming: streaming
:return:
"""
try:
workflow = cls._get_workflow(pipeline, invoke_from)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().generate(
pipeline=pipeline,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
workflow_thread_pool_id=None,
),
)

except Exception:
raise

@staticmethod
def _get_max_active_requests(app_model: App) -> int:
max_active_requests = app_model.max_active_requests
if max_active_requests is None:
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
return max_active_requests

@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")

@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)

@classmethod
def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow:
"""
Get workflow
:param pipeline: pipeline
:param invoke_from: invoke from
:return:
"""
rag_pipeline_service = RagPipelineService()
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)

if not workflow:
raise ValueError("Workflow not initialized")
else:
# fetch published workflow by app_model
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)

if not workflow:
raise ValueError("Workflow not published")

return workflow
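
A minimal sketch of calling the new service from a controller (the pipeline and user objects are placeholders, and the args mapping is illustrative; the inputs a pipeline run actually expects are defined by its workflow variables):

    response = PipelineGenerateService.generate(
        pipeline=pipeline,
        user=current_user,
        args={"inputs": {}, "datasource_type": "upload_file"},
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,
    )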

+ 15
- 15
api/services/rag_pipeline/pipeline_template/database/database_retrieval.py Ver arquivo

@@ -29,32 +29,31 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
PipelineBuiltInTemplate.language == language
).all()
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
)

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline = pipeline_built_in_template.pipeline

recommended_pipeline_result = {
'id': pipeline_built_in_template.id,
'name': pipeline_built_in_template.name,
'pipeline_id': pipeline_model.id,
'description': pipeline_built_in_template.description,
'icon': pipeline_built_in_template.icon,
'copyright': pipeline_built_in_template.copyright,
'privacy_policy': pipeline_built_in_template.privacy_policy,
'position': pipeline_built_in_template.position,
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"pipeline_id": pipeline_model.id,
"description": pipeline_built_in_template.description,
"icon": pipeline_built_in_template.icon,
"copyright": pipeline_built_in_template.copyright,
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
}
dataset: Dataset = pipeline_model.dataset
if dataset:
recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)

return {'pipeline_templates': recommended_pipelines_results}

return {"pipeline_templates": recommended_pipelines_results}

@classmethod
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
@@ -64,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:return:
"""
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()

+ 67
- 8
api/services/rag_pipeline/rag_pipeline.py Ver arquivo

@@ -3,7 +3,7 @@ import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Literal, Optional
from typing import Any, Optional
from uuid import uuid4

from flask_login import current_user
@@ -46,7 +46,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
class RagPipelineService:
@staticmethod
def get_pipeline_templates(
type: Literal["built-in", "customized"] = "built-in", language: str = "en-US"
type: str = "built-in", language: str = "en-US"
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
@@ -358,11 +358,11 @@ class RagPipelineService:

return workflow_node_execution

def run_datasource_workflow_node(
def run_published_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
"""
Run published workflow datasource
Run published workflow node
"""
# fetch published workflow by app_model
published_workflow = self.get_published_workflow(pipeline=pipeline)
@@ -393,6 +393,41 @@ class RagPipelineService:

return workflow_node_execution

def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> dict:
"""
Run published workflow datasource
"""
# fetch published workflow by app_model
published_workflow = self.get_published_workflow(pipeline=pipeline)
if not published_workflow:
raise ValueError("Workflow not initialized")

        # run datasource node
start_at = time.perf_counter()

datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {})
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager

datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=datasource_node_data.get("provider_id"),
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
)
result = datasource_runtime._invoke_first_step(
inputs=user_inputs,
provider_type=datasource_node_data.get("provider_type"),
user_id=account.id,
)

return {
"result": result,
"provider_type": datasource_node_data.get("provider_type"),
}

def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@@ -552,7 +587,7 @@ class RagPipelineService:

return workflow

def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list:
"""
Get second step parameters of rag pipeline
"""
@@ -567,9 +602,33 @@ class RagPipelineService:
            return []

# get datasource provider
datasource_provider_variables = [item for item in rag_pipeline_variables
if item.get("belong_to_node_id") == node_id
or item.get("belong_to_node_id") == "shared"]
datasource_provider_variables = [
item
for item in rag_pipeline_variables
if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
return datasource_provider_variables

    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list:
"""
Get second step parameters of rag pipeline
"""

workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")

# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
            return []

# get datasource provider
datasource_provider_variables = [
item
for item in rag_pipeline_variables
if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
return datasource_provider_variables

def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
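
A minimal sketch of exercising the split second-step helpers above (pipeline and node_id are placeholders; both helpers return the datasource-provider variable definitions that belong to the given node or are shared across nodes):

    service = RagPipelineService()
    draft_params = service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
    published_params = service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)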
