| @@ -39,9 +39,9 @@ from libs.helper import TimestampField, uuid_value | |||
| from libs.login import current_user, login_required | |||
| from models.account import Account | |||
| from models.dataset import Pipeline | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.errors.llm import InvokeRateLimitError | |||
| from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService | |||
| from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | |||
| @@ -170,7 +170,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate_single_iteration( | |||
| response = PipelineGenerateService.generate_single_iteration( | |||
| pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True | |||
| ) | |||
| @@ -207,7 +207,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate_single_loop( | |||
| response = PipelineGenerateService.generate_single_loop( | |||
| pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True | |||
| ) | |||
| @@ -241,11 +241,12 @@ class DraftRagPipelineRunApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("datasource_type", type=str, required=True, location="json") | |||
| parser.add_argument("datasource_info", type=list, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| response = PipelineGenerateService.generate( | |||
| pipeline=pipeline, | |||
| user=current_user, | |||
| args=args, | |||
| @@ -258,7 +259,73 @@ class DraftRagPipelineRunApi(Resource): | |||
| raise InvokeRateLimitHttpError(ex.description) | |||
| class PublishedRagPipelineRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def post(self, pipeline: Pipeline): | |||
| """ | |||
| Run published workflow | |||
| """ | |||
| # The role of the current user in the tenant must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("datasource_type", type=str, required=True, location="json") | |||
| parser.add_argument("datasource_info", type=list, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = PipelineGenerateService.generate( | |||
| pipeline=pipeline, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.PUBLISHED, | |||
| streaming=True, | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| except InvokeRateLimitError as ex: | |||
| raise InvokeRateLimitHttpError(ex.description) | |||
| class RagPipelineDatasourceNodeRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def post(self, pipeline: Pipeline, node_id: str): | |||
| """ | |||
| Run rag pipeline datasource | |||
| """ | |||
| # The role of the current user in the tenant must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| inputs = args.get("inputs") | |||
| rag_pipeline_service = RagPipelineService() | |||
| result = rag_pipeline_service.run_datasource_workflow_node( | |||
| pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user | |||
| ) | |||
| return result | |||
| class RagPipelinePublishedNodeRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -283,7 +350,7 @@ class RagPipelineDatasourceNodeRunApi(Resource): | |||
| raise ValueError("missing inputs") | |||
| rag_pipeline_service = RagPipelineService() | |||
| workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node( | |||
| workflow_node_execution = rag_pipeline_service.run_published_workflow_node( | |||
| pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user | |||
| ) | |||
| @@ -354,7 +421,8 @@ class PublishedRagPipelineApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not pipeline.is_published: | |||
| return None | |||
| # fetch published workflow by pipeline | |||
| rag_pipeline_service = RagPipelineService() | |||
| workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) | |||
| @@ -397,10 +465,8 @@ class PublishedRagPipelineApi(Resource): | |||
| marked_name=args.marked_name or "", | |||
| marked_comment=args.marked_comment or "", | |||
| ) | |||
| pipeline.is_published = True | |||
| pipeline.workflow_id = workflow.id | |||
| db.session.commit() | |||
| workflow_created_at = TimestampField().format(workflow.created_at) | |||
| session.commit() | |||
| @@ -617,7 +683,7 @@ class RagPipelineByIdApi(Resource): | |||
| return None, 204 | |||
| class RagPipelineSecondStepApi(Resource): | |||
| class PublishedRagPipelineSecondStepApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -632,9 +698,28 @@ class RagPipelineSecondStepApi(Resource): | |||
| node_id = request.args.get("node_id", required=True, type=str) | |||
| rag_pipeline_service = RagPipelineService() | |||
| variables = rag_pipeline_service.get_second_step_parameters( | |||
| pipeline=pipeline, node_id=node_id | |||
| ) | |||
| variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) | |||
| return { | |||
| "variables": variables, | |||
| } | |||
| class DraftRagPipelineSecondStepApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_rag_pipeline | |||
| def get(self, pipeline: Pipeline): | |||
| """ | |||
| Get second step parameters of rag pipeline | |||
| """ | |||
| # The role of the current user in the tenant must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| node_id = request.args.get("node_id", required=True, type=str) | |||
| rag_pipeline_service = RagPipelineService() | |||
| variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) | |||
| return { | |||
| "variables": variables, | |||
| } | |||
| @@ -732,15 +817,21 @@ api.add_resource( | |||
| RagPipelineDraftNodeRunApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run", | |||
| ) | |||
| # api.add_resource( | |||
| # RagPipelinePublishedNodeRunApi, | |||
| # "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run", | |||
| # ) | |||
| api.add_resource( | |||
| RagPipelineDatasourceNodeRunApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/datasource/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineDraftRunIterationNodeApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelinePublishedNodeRunApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineDraftRunLoopNodeApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run", | |||
| @@ -762,7 +853,6 @@ api.add_resource( | |||
| DefaultRagPipelineBlockConfigApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineByIdApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>", | |||
| @@ -784,6 +874,10 @@ api.add_resource( | |||
| "/rag/pipelines/datasource-plugins", | |||
| ) | |||
| api.add_resource( | |||
| RagPipelineSecondStepApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/processing/paramters", | |||
| PublishedRagPipelineSecondStepApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters", | |||
| ) | |||
| api.add_resource( | |||
| DraftRagPipelineSecondStepApi, | |||
| "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters", | |||
| ) | |||
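# --- Usage sketch (not part of the diff) ---
# A minimal client call against the datasource node run endpoint registered above.
# The base URL, console API prefix, pipeline/node ids, and the bearer token are all
# placeholders; only the "inputs" field is known to be required by the request parser.
import requests

BASE_URL = "https://dify.example.com/console/api"  # assumed console API prefix
PIPELINE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder pipeline id
NODE_ID = "datasource-node-1"  # placeholder node id

resp = requests.post(
    f"{BASE_URL}/rag/pipelines/{PIPELINE_ID}/workflows/datasource/nodes/{NODE_ID}/run",
    headers={"Authorization": "Bearer <console-access-token>"},  # assumed auth scheme
    json={"inputs": {"url": "https://example.com/docs"}},  # "inputs" is the only required key
    timeout=30,
)
print(resp.status_code, resp.json())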
| @@ -283,7 +283,7 @@ class AppConfig(BaseModel): | |||
| tenant_id: str | |||
| app_id: str | |||
| app_mode: AppMode | |||
| additional_features: AppAdditionalFeatures | |||
| additional_features: Optional[AppAdditionalFeatures] = None | |||
| variables: list[VariableEntity] = [] | |||
| sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None | |||
| @@ -0,0 +1,95 @@ | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ErrorStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeStartStreamResponse, | |||
| PingStreamResponse, | |||
| WorkflowAppBlockingResponse, | |||
| WorkflowAppStreamResponse, | |||
| ) | |||
| class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| _blocking_response_type = WorkflowAppBlockingResponse | |||
| @classmethod | |||
| def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| """ | |||
| Convert blocking full response. | |||
| :param blocking_response: blocking response | |||
| :return: | |||
| """ | |||
| return dict(blocking_response.to_dict()) | |||
| @classmethod | |||
| def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] | |||
| """ | |||
| Convert blocking simple response. | |||
| :param blocking_response: blocking response | |||
| :return: | |||
| """ | |||
| return cls.convert_blocking_full_response(blocking_response) | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| :param stream_response: stream response | |||
| :return: | |||
| """ | |||
| for chunk in stream_response: | |||
| chunk = cast(WorkflowAppStreamResponse, chunk) | |||
| sub_stream_response = chunk.stream_response | |||
| if isinstance(sub_stream_response, PingStreamResponse): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| if isinstance(sub_stream_response, ErrorStreamResponse): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield response_chunk | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| :param stream_response: stream response | |||
| :return: | |||
| """ | |||
| for chunk in stream_response: | |||
| chunk = cast(WorkflowAppStreamResponse, chunk) | |||
| sub_stream_response = chunk.stream_response | |||
| if isinstance(sub_stream_response, PingStreamResponse): | |||
| yield "ping" | |||
| continue | |||
| response_chunk = { | |||
| "event": sub_stream_response.event.value, | |||
| "workflow_run_id": chunk.workflow_run_id, | |||
| } | |||
| if isinstance(sub_stream_response, ErrorStreamResponse): | |||
| data = cls._error_to_stream_response(sub_stream_response.err) | |||
| response_chunk.update(data) | |||
| elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): | |||
| response_chunk.update(sub_stream_response.to_ignore_detail_dict()) | |||
| else: | |||
| response_chunk.update(sub_stream_response.to_dict()) | |||
| yield response_chunk | |||
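# --- Usage sketch (not part of the diff) ---
# How a caller might drain the converted stream: each item is either the literal
# string "ping" (keep-alive) or a dict carrying "event" and "workflow_run_id".
# `stream` stands for any Generator[AppStreamResponse, None, None] produced by the
# task pipeline; the SSE framing below is illustrative, not the helper's actual output.
import json


def iter_sse_lines(stream):
    for item in WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream):
        if item == "ping":
            yield "event: ping\n\n"
        else:
            yield f"data: {json.dumps(item)}\n\n"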
| @@ -0,0 +1,63 @@ | |||
| from core.app.app_config.base_app_config_manager import BaseAppConfigManager | |||
| from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager | |||
| from core.app.app_config.entities import WorkflowUIBasedAppConfig | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager | |||
| from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager | |||
| from models.dataset import Pipeline | |||
| from models.model import AppMode | |||
| from models.workflow import Workflow | |||
| class PipelineConfig(WorkflowUIBasedAppConfig): | |||
| """ | |||
| Pipeline Config Entity. | |||
| """ | |||
| pass | |||
| class PipelineConfigManager(BaseAppConfigManager): | |||
| @classmethod | |||
| def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: | |||
| pipeline_config = PipelineConfig( | |||
| tenant_id=pipeline.tenant_id, | |||
| app_id=pipeline.id, | |||
| app_mode=AppMode.RAG_PIPELINE, | |||
| workflow_id=workflow.id, | |||
| variables=WorkflowVariablesConfigManager.convert(workflow=workflow), | |||
| ) | |||
| return pipeline_config | |||
| @classmethod | |||
| def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: | |||
| """ | |||
| Validate for pipeline config | |||
| :param tenant_id: tenant id | |||
| :param config: app model config args | |||
| :param only_structure_validate: only validate the structure of the config | |||
| """ | |||
| related_config_keys = [] | |||
| # file upload validation | |||
| config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| # text_to_speech | |||
| config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| # moderation validation | |||
| config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( | |||
| tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate | |||
| ) | |||
| related_config_keys.extend(current_related_config_keys) | |||
| related_config_keys = list(set(related_config_keys)) | |||
| # Filter out extra parameters | |||
| filtered_config = {key: config.get(key) for key in related_config_keys} | |||
| return filtered_config | |||
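# --- Illustrative call (not part of the diff) ---
# config_validate keeps only the keys the three feature managers report as related and
# drops everything else. The tenant id and the input keys below are placeholders; which
# keys survive depends on FileUploadConfigManager, TextToSpeechConfigManager and
# SensitiveWordAvoidanceConfigManager.
raw_config = {
    "file_upload": {"enabled": False},
    "text_to_speech": {"enabled": False},
    "unrelated_key": "dropped by the final filter",
}
filtered = PipelineConfigManager.config_validate(tenant_id="tenant-placeholder", config=raw_config)
# "unrelated_key" is not in related_config_keys, so it does not appear in `filtered`.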
| @@ -0,0 +1,496 @@ | |||
| import contextvars | |||
| import datetime | |||
| import json | |||
| import logging | |||
| import random | |||
| import threading | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import sessionmaker | |||
| import contexts | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_generator import BaseAppGenerator | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager | |||
| from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager | |||
| from core.app.apps.pipeline.pipeline_runner import PipelineRunner | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField | |||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||
| from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from extensions.ext_database import db | |||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | |||
| from models.dataset import Document, Pipeline | |||
| from services.dataset_service import DocumentService | |||
| logger = logging.getLogger(__name__) | |||
| class PipelineGenerator(BaseAppGenerator): | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[True], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Generator[Mapping | str, None, None]: ... | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: Literal[False], | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Mapping[str, Any]: ... | |||
| @overload | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool, | |||
| call_depth: int, | |||
| workflow_thread_pool_id: Optional[str], | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... | |||
| def generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: | |||
| # convert to app config | |||
| pipeline_config = PipelineConfigManager.get_pipeline_config( | |||
| pipeline=pipeline, | |||
| workflow=workflow, | |||
| ) | |||
| inputs: Mapping[str, Any] = args["inputs"] | |||
| datasource_type: str = args["datasource_type"] | |||
| datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] | |||
| batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) | |||
| for datasource_info in datasource_info_list: | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| document_id = None | |||
| if invoke_from == InvokeFrom.PUBLISHED: | |||
| position = DocumentService.get_documents_position(pipeline.dataset_id) | |||
| document = self._build_document( | |||
| tenant_id=pipeline.tenant_id, | |||
| dataset_id=pipeline.dataset_id, | |||
| built_in_field_enabled=pipeline.dataset.built_in_field_enabled, | |||
| datasource_type=datasource_type, | |||
| datasource_info=datasource_info, | |||
| created_from="rag-pipeline", | |||
| position=position, | |||
| account=user, | |||
| batch=batch, | |||
| document_form=pipeline.dataset.doc_form, | |||
| ) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| document_id = document.id | |||
| # init application generate entity | |||
| application_generate_entity = RagPipelineGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| pipeline_config=pipeline_config, | |||
| datasource_type=datasource_type, | |||
| datasource_info=datasource_info, | |||
| dataset_id=pipeline.dataset_id, | |||
| batch=batch, | |||
| document_id=document_id, | |||
| inputs=self._prepare_user_inputs( | |||
| user_inputs=inputs, | |||
| variables=pipeline_config.variables, | |||
| tenant_id=pipeline.tenant_id, | |||
| strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, | |||
| ), | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=invoke_from, | |||
| call_depth=call_depth, | |||
| workflow_run_id=workflow_run_id, | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| return self._generate( | |||
| pipeline=pipeline, | |||
| workflow=workflow, | |||
| user=user, | |||
| application_generate_entity=application_generate_entity, | |||
| invoke_from=invoke_from, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| def _generate( | |||
| self, | |||
| *, | |||
| pipeline: Pipeline, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| invoke_from: InvokeFrom, | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate pipeline response. | |||
| :param pipeline: Pipeline | |||
| :param workflow: Workflow | |||
| :param user: account or end user | |||
| :param application_generate_entity: application generate entity | |||
| :param invoke_from: invoke from source | |||
| :param workflow_node_execution_repository: repository for workflow node execution | |||
| :param streaming: is stream | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| # init queue manager | |||
| queue_manager = PipelineQueueManager( | |||
| task_id=application_generate_entity.task_id, | |||
| user_id=application_generate_entity.user_id, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| app_mode=pipeline.mode, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "context": contextvars.copy_context(), | |||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||
| }, | |||
| ) | |||
| worker_thread.start() | |||
| # return response or stream generator | |||
| response = self._handle_response( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| stream=streaming, | |||
| ) | |||
| return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) | |||
| def single_iteration_generate( | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| :param app_model: App | |||
| :param workflow: Workflow | |||
| :param node_id: the node id | |||
| :param user: account or end user | |||
| :param args: request args | |||
| :param streaming: is streamed | |||
| """ | |||
| if not node_id: | |||
| raise ValueError("node_id is required") | |||
| if args.get("inputs") is None: | |||
| raise ValueError("inputs is required") | |||
| # convert to app config | |||
| app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) | |||
| # init application generate entity | |||
| application_generate_entity = WorkflowAppGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| inputs={}, | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras={"auto_generate_conversation_name": False}, | |||
| single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | |||
| node_id=node_id, inputs=args["inputs"] | |||
| ), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| ) | |||
| def single_loop_generate( | |||
| self, | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| :param app_model: App | |||
| :param workflow: Workflow | |||
| :param node_id: the node id | |||
| :param user: account or end user | |||
| :param args: request args | |||
| :param streaming: is streamed | |||
| """ | |||
| if not node_id: | |||
| raise ValueError("node_id is required") | |||
| if args.get("inputs") is None: | |||
| raise ValueError("inputs is required") | |||
| # convert to app config | |||
| app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) | |||
| # init application generate entity | |||
| application_generate_entity = WorkflowAppGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| inputs={}, | |||
| files=[], | |||
| user_id=user.id, | |||
| stream=streaming, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| extras={"auto_generate_conversation_name": False}, | |||
| single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| # Create workflow node execution repository | |||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | |||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||
| session_factory=session_factory, | |||
| user=user, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| return self._generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| streaming=streaming, | |||
| ) | |||
| def _generate_worker( | |||
| self, | |||
| flask_app: Flask, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| context: contextvars.Context, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| :return: | |||
| """ | |||
| for var, val in context.items(): | |||
| var.set(val) | |||
| with flask_app.app_context(): | |||
| try: | |||
| # workflow app | |||
| runner = PipelineRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| runner.run() | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER | |||
| ) | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| finally: | |||
| db.session.close() | |||
| def _handle_response( | |||
| self, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| workflow_node_execution_repository: WorkflowNodeExecutionRepository, | |||
| stream: bool = False, | |||
| ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| """ | |||
| Handle response. | |||
| :param application_generate_entity: application generate entity | |||
| :param workflow: workflow | |||
| :param queue_manager: queue manager | |||
| :param user: account or end user | |||
| :param stream: is stream | |||
| :param workflow_node_execution_repository: repository for workflow node execution | |||
| :return: | |||
| """ | |||
| # init generate task pipeline | |||
| generate_task_pipeline = WorkflowAppGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow=workflow, | |||
| queue_manager=queue_manager, | |||
| user=user, | |||
| stream=stream, | |||
| workflow_node_execution_repository=workflow_node_execution_repository, | |||
| ) | |||
| try: | |||
| return generate_task_pipeline.process() | |||
| except ValueError as e: | |||
| if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise GenerateTaskStoppedError() | |||
| else: | |||
| logger.exception( | |||
| f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}" | |||
| ) | |||
| raise e | |||
| def _build_document( | |||
| self, | |||
| tenant_id: str, | |||
| dataset_id: str, | |||
| built_in_field_enabled: bool, | |||
| datasource_type: str, | |||
| datasource_info: Mapping[str, Any], | |||
| created_from: str, | |||
| position: int, | |||
| account: Account, | |||
| batch: str, | |||
| document_form: str, | |||
| ): | |||
| if datasource_type == "local_file": | |||
| name = datasource_info["name"] | |||
| elif datasource_type == "online_document": | |||
| name = datasource_info["page_title"] | |||
| elif datasource_type == "website_crawl": | |||
| name = datasource_info["title"] | |||
| else: | |||
| raise ValueError(f"Unsupported datasource type: {datasource_type}") | |||
| document = Document( | |||
| tenant_id=tenant_id, | |||
| dataset_id=dataset_id, | |||
| position=position, | |||
| data_source_type=datasource_type, | |||
| data_source_info=json.dumps(datasource_info), | |||
| batch=batch, | |||
| name=name, | |||
| created_from=created_from, | |||
| created_by=account.id, | |||
| doc_form=document_form, | |||
| ) | |||
| doc_metadata = {} | |||
| if built_in_field_enabled: | |||
| doc_metadata = { | |||
| BuiltInField.document_name: name, | |||
| BuiltInField.uploader: account.name, | |||
| BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.source: datasource_type, | |||
| } | |||
| if doc_metadata: | |||
| document.doc_metadata = doc_metadata | |||
| return document | |||
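# --- Usage sketch (not part of the diff) ---
# Driving the generator above. Assumptions: `pipeline`, `workflow` and `account` are
# already-loaded ORM objects, and the args keys mirror what generate() reads
# ("inputs", "datasource_type", "datasource_info_list"). With streaming=True the call
# returns a generator of mappings/strings suitable for SSE.
def run_draft_pipeline(pipeline: Pipeline, workflow: Workflow, account: Account):
    generator = PipelineGenerator()
    return generator.generate(
        pipeline=pipeline,
        workflow=workflow,
        user=account,
        args={
            "inputs": {},  # start-node user inputs
            "datasource_type": "website_crawl",
            "datasource_info_list": [{"title": "Example", "source_url": "https://example.com"}],
        },
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,
    )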
| @@ -0,0 +1,44 @@ | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueErrorEvent, | |||
| QueueMessageEndEvent, | |||
| QueueStopEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| WorkflowQueueMessage, | |||
| ) | |||
| class PipelineQueueManager(AppQueueManager): | |||
| def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: | |||
| super().__init__(task_id, user_id, invoke_from) | |||
| self._app_mode = app_mode | |||
| def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| :param pub_from: | |||
| :return: | |||
| """ | |||
| message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) | |||
| self._q.put(message) | |||
| if isinstance( | |||
| event, | |||
| QueueStopEvent | |||
| | QueueErrorEvent | |||
| | QueueMessageEndEvent | |||
| | QueueWorkflowSucceededEvent | |||
| | QueueWorkflowFailedEvent | |||
| | QueueWorkflowPartialSuccessEvent, | |||
| ): | |||
| self.stop_listen() | |||
| if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): | |||
| raise GenerateTaskStoppedError() | |||
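# --- Construction sketch (not part of the diff) ---
# The worker thread publishes events through _publish(); a consumer drains them until a
# terminal event stops the listener. The values below are placeholders, and the
# listen() generator is assumed to be inherited from AppQueueManager as in the other
# app queue managers.
queue_manager = PipelineQueueManager(
    task_id="task-placeholder",
    user_id="user-placeholder",
    invoke_from=InvokeFrom.DEBUGGER,
    app_mode="rag-pipeline",  # assumed mode string for pipelines
)
# for message in queue_manager.listen():  # assumption: inherited blocking generator
#     ...  # handle each WorkflowQueueMessage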
| @@ -0,0 +1,154 @@ | |||
| import logging | |||
| from typing import Optional, cast | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| RagPipelineGenerateEntity, | |||
| ) | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.dataset import Pipeline | |||
| from models.enums import UserFrom | |||
| from models.model import EndUser | |||
| from models.workflow import Workflow, WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| class PipelineRunner(WorkflowBasedAppRunner): | |||
| """ | |||
| Pipeline Application Runner | |||
| """ | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: RagPipelineGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| self.application_generate_entity = application_generate_entity | |||
| self.queue_manager = queue_manager | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| """ | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(PipelineConfig, app_config) | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = self.application_generate_entity.user_id | |||
| pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() | |||
| if not pipeline: | |||
| raise ValueError("Pipeline not found") | |||
| workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| db.session.close() | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # if only single iteration run is requested | |||
| if self.application_generate_entity.single_iteration_run: | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| # if only single loop run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_loop_run.inputs, | |||
| ) | |||
| else: | |||
| inputs = self.application_generate_entity.inputs | |||
| files = self.application_generate_entity.files | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, | |||
| SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, | |||
| SystemVariableKey.BATCH: self.application_generate_entity.batch, | |||
| SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, | |||
| } | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=[], | |||
| ) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=workflow.graph_dict) | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| call_depth=self.application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| thread_pool_id=self.workflow_thread_pool_id, | |||
| ) | |||
| generator = workflow_entry.run(callbacks=workflow_callbacks) | |||
| for event in generator: | |||
| self._handle_event(workflow_entry, event) | |||
| def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| @@ -21,6 +21,7 @@ class InvokeFrom(Enum): | |||
| WEB_APP = "web-app" | |||
| EXPLORE = "explore" | |||
| DEBUGGER = "debugger" | |||
| PUBLISHED = "published" | |||
| @classmethod | |||
| def value_of(cls, value: str): | |||
| @@ -226,3 +227,37 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| inputs: dict | |||
| single_loop_run: Optional[SingleLoopRunEntity] = None | |||
| class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): | |||
| """ | |||
| RAG Pipeline Application Generate Entity. | |||
| """ | |||
| # app config | |||
| pipeline_config: WorkflowUIBasedAppConfig | |||
| datasource_type: str | |||
| datasource_info: Mapping[str, Any] | |||
| dataset_id: str | |||
| batch: str | |||
| document_id: str | |||
| class SingleIterationRunEntity(BaseModel): | |||
| """ | |||
| Single Iteration Run Entity. | |||
| """ | |||
| node_id: str | |||
| inputs: dict | |||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | |||
| class SingleLoopRunEntity(BaseModel): | |||
| """ | |||
| Single Loop Run Entity. | |||
| """ | |||
| node_id: str | |||
| inputs: dict | |||
| single_loop_run: Optional[SingleLoopRunEntity] = None | |||
| @@ -1,18 +1,13 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from abc import ABC, abstractmethod | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.plugin.utils.converter import convert_parameters_to_plugin_format | |||
| class DatasourcePlugin: | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| class DatasourcePlugin(ABC): | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| @@ -20,57 +15,19 @@ class DatasourcePlugin: | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| self.entity = entity | |||
| self.runtime = runtime | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def _invoke_first_step( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: dict[str, Any], | |||
| ) -> Mapping[str, Any]: | |||
| manager = PluginDatasourceManager() | |||
| datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) | |||
| return manager.invoke_first_step( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| ) | |||
| def _invoke_second_step( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: dict[str, Any], | |||
| ) -> Mapping[str, Any]: | |||
| manager = PluginDatasourceManager() | |||
| datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) | |||
| return manager.invoke_second_step( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| ) | |||
| @abstractmethod | |||
| def datasource_provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the datasource provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return DatasourcePlugin( | |||
| entity=self.entity, | |||
| return self.__class__( | |||
| entity=self.entity.model_copy(), | |||
| runtime=runtime, | |||
| tenant_id=self.tenant_id, | |||
| icon=self.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -1,26 +1,19 @@ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.entities.provider_entities import ProviderConfig | |||
| from core.plugin.impl.tool import PluginToolManager | |||
| from core.tools.errors import ToolProviderCredentialValidationError | |||
| class DatasourcePluginProviderController: | |||
| class DatasourcePluginProviderController(ABC): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| tenant_id: str | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None: | |||
| self.entity = entity | |||
| self.tenant_id = tenant_id | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def need_credentials(self) -> bool: | |||
| @@ -44,29 +37,19 @@ class DatasourcePluginProviderController: | |||
| ): | |||
| raise ToolProviderCredentialValidationError("Invalid credentials") | |||
| def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| @abstractmethod | |||
| def get_datasource(self, datasource_name: str) -> DatasourcePlugin: | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return DatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| pass | |||
| def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore | |||
| """ | |||
| @@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel): | |||
| description: I18nObject | |||
| icon: str | dict | |||
| label: I18nObject # label | |||
| type: ToolProviderType | |||
| type: str | |||
| masked_credentials: Optional[dict] = None | |||
| original_credentials: Optional[dict] = None | |||
| is_team_authorization: bool = False | |||
| allow_delete: bool = True | |||
| plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") | |||
| plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") | |||
| plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") | |||
| plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") | |||
| datasources: list[DatasourceApiEntity] = Field(default_factory=list) | |||
| labels: list[str] = Field(default_factory=list) | |||
| @@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum): | |||
| ONLINE_DOCUMENT = "online_document" | |||
| LOCAL_FILE = "local_file" | |||
| WEBSITE = "website" | |||
| WEBSITE_CRAWL = "website_crawl" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "DatasourceProviderType": | |||
| @@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter): | |||
| class DatasourceIdentity(BaseModel): | |||
| author: str = Field(..., description="The author of the tool") | |||
| name: str = Field(..., description="The name of the tool") | |||
| label: I18nObject = Field(..., description="The label of the tool") | |||
| provider: str = Field(..., description="The provider of the tool") | |||
| author: str = Field(..., description="The author of the datasource") | |||
| name: str = Field(..., description="The name of the datasource") | |||
| label: I18nObject = Field(..., description="The label of the datasource") | |||
| provider: str = Field(..., description="The provider of the datasource") | |||
| icon: Optional[str] = None | |||
| @@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity): | |||
| class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): | |||
| datasources: list[DatasourceEntity] = Field(default_factory=list) | |||
| datasources: list[DatasourceEntity] = Field(default_factory=list) | |||
| class DatasourceInvokeMeta(BaseModel): | |||
| @@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum): | |||
| """ | |||
| RAG_PIPELINE = "rag_pipeline" | |||
| class GetOnlineDocumentPagesRequest(BaseModel): | |||
| """ | |||
| Get online document pages request | |||
| """ | |||
| tenant_id: str = Field(..., description="The tenant id") | |||
| class OnlineDocumentPageIcon(BaseModel): | |||
| """ | |||
| Online document page icon | |||
| """ | |||
| type: str = Field(..., description="The type of the icon") | |||
| url: str = Field(..., description="The url of the icon") | |||
| class OnlineDocumentPage(BaseModel): | |||
| """ | |||
| Online document page | |||
| """ | |||
| page_id: str = Field(..., description="The page id") | |||
| page_title: str = Field(..., description="The page title") | |||
| page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon") | |||
| type: str = Field(..., description="The type of the page") | |||
| last_edited_time: str = Field(..., description="The last edited time") | |||
| class OnlineDocumentInfo(BaseModel): | |||
| """ | |||
| Online document info | |||
| """ | |||
| workspace_id: str = Field(..., description="The workspace id") | |||
| workspace_name: str = Field(..., description="The workspace name") | |||
| workspace_icon: str = Field(..., description="The workspace icon") | |||
| total: int = Field(..., description="The total number of documents") | |||
| pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") | |||
| class GetOnlineDocumentPagesResponse(BaseModel): | |||
| """ | |||
| Get online document pages response | |||
| """ | |||
| result: list[OnlineDocumentInfo] | |||
| class GetOnlineDocumentPageContentRequest(BaseModel): | |||
| """ | |||
| Get online document page content request | |||
| """ | |||
| online_document_info_list: list[OnlineDocumentInfo] | |||
| class OnlineDocumentPageContent(BaseModel): | |||
| """ | |||
| Online document page content | |||
| """ | |||
| page_id: str = Field(..., description="The page id") | |||
| content: str = Field(..., description="The content of the page") | |||
| class GetOnlineDocumentPageContentResponse(BaseModel): | |||
| """ | |||
| Get online document page content response | |||
| """ | |||
| result: list[OnlineDocumentPageContent] | |||
| class GetWebsiteCrawlRequest(BaseModel): | |||
| """ | |||
| Get website crawl request | |||
| """ | |||
| url: str = Field(..., description="The url of the website") | |||
| crawl_parameters: dict = Field(..., description="The crawl parameters") | |||
| class WebSiteInfo(BaseModel): | |||
| """ | |||
| Website info | |||
| """ | |||
| source_url: str = Field(..., description="The url of the website") | |||
| markdown: str = Field(..., description="The markdown of the website") | |||
| title: str = Field(..., description="The title of the website") | |||
| description: str = Field(..., description="The description of the website") | |||
| class GetWebsiteCrawlResponse(BaseModel): | |||
| """ | |||
| Get website crawl response | |||
| """ | |||
| result: list[WebSiteInfo] | |||
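# --- Example payloads (not part of the diff; all values are invented placeholders) ---
page = OnlineDocumentPage(
    page_id="pg-1",
    page_title="Getting started",
    page_icon=OnlineDocumentPageIcon(type="url", url="https://example.com/icon.png"),
    type="page",
    last_edited_time="2024-01-01T00:00:00Z",
)
info = OnlineDocumentInfo(
    workspace_id="ws-1",
    workspace_name="Docs",
    workspace_icon="https://example.com/workspace.png",
    total=1,
    pages=[page],
)
pages_response = GetOnlineDocumentPagesResponse(result=[info])
crawl_response = GetWebsiteCrawlResponse(
    result=[
        WebSiteInfo(
            source_url="https://example.com",
            markdown="# Example",
            title="Example",
            description="Demo page",
        )
    ]
)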
| @@ -0,0 +1,37 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| ) | |||
| class LocalFileDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def datasource_provider_type(self) -> DatasourceProviderType: | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return LocalFileDatasourcePlugin( | |||
| entity=self.entity, | |||
| runtime=runtime, | |||
| tenant_id=self.tenant_id, | |||
| icon=self.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,58 @@ | |||
| from typing import Any | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin | |||
| class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| tenant_id: str | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity) | |||
| self.tenant_id = tenant_id | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.LOCAL_FILE | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| """ | |||
| validate the credentials of the provider | |||
| """ | |||
| pass | |||
| def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return LocalFileDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
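To make the provider-controller flow concrete, a hedged sketch of resolving the local file datasource follows; the plugin identifiers mirror the hard-coded values used elsewhere in this PR, and the provider entity is assumed to be loaded beforehand.

```python
# Sketch only: assumes LocalFileDatasourcePluginProviderController and DatasourceProviderType
# are importable as defined above, and that `provider_entity` (a
# DatasourceProviderEntityWithPlugin) has already been loaded elsewhere.

def resolve_upload_file_datasource(provider_entity, tenant_id: str):
    controller = LocalFileDatasourcePluginProviderController(
        entity=provider_entity,
        plugin_id="langgenius/file",                           # identifier used by this PR's local-file provider
        plugin_unique_identifier="langgenius/file:0.0.1@dify",
        tenant_id=tenant_id,
    )
    # get_datasource() raises ValueError when the name is not declared by the provider.
    datasource = controller.get_datasource("upload_file")
    assert datasource.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE
    return datasource
```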
| @@ -0,0 +1,80 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| GetOnlineDocumentPageContentRequest, | |||
| GetOnlineDocumentPageContentResponse, | |||
| GetOnlineDocumentPagesRequest, | |||
| GetOnlineDocumentPagesResponse, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| class OnlineDocumentDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def _get_online_document_pages( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: GetOnlineDocumentPagesRequest, | |||
| provider_type: str, | |||
| ) -> GetOnlineDocumentPagesResponse: | |||
| manager = PluginDatasourceManager() | |||
| return manager.get_online_document_pages( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def _get_online_document_page_content( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: GetOnlineDocumentPageContentRequest, | |||
| provider_type: str, | |||
| ) -> GetOnlineDocumentPageContentResponse: | |||
| manager = PluginDatasourceManager() | |||
| return manager.get_online_document_page_content( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def datasource_provider_type(self) -> DatasourceProviderType: | |||
| return DatasourceProviderType.ONLINE_DOCUMENT | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return OnlineDocumentDatasourcePlugin( | |||
| entity=self.entity, | |||
| runtime=runtime, | |||
| tenant_id=self.tenant_id, | |||
| icon=self.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
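For context, a hedged sketch of how a caller (such as the datasource node further down in this diff) might drive this plugin; the request object and an already-configured plugin instance are assumptions.

```python
# Sketch only: `plugin` is an OnlineDocumentDatasourcePlugin whose runtime already carries
# credentials; `pages_request` is a GetOnlineDocumentPagesRequest.

def list_online_document_pages(plugin, user_id: str, pages_request):
    # Delegates to PluginDatasourceManager.get_online_document_pages under the hood.
    response = plugin._get_online_document_pages(
        user_id=user_id,
        datasource_parameters=pages_request,
        provider_type=plugin.datasource_provider_type().value,  # assumed to be the string the daemon route expects
    )
    return response  # GetOnlineDocumentPagesResponse
```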
| @@ -0,0 +1,50 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin | |||
| class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| tenant_id: str | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity) | |||
| self.tenant_id = tenant_id | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.ONLINE_DOCUMENT | |||
| def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return OnlineDocumentDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,63 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceEntity, | |||
| DatasourceProviderType, | |||
| GetWebsiteCrawlRequest, | |||
| GetWebsiteCrawlResponse, | |||
| ) | |||
| from core.plugin.impl.datasource import PluginDatasourceManager | |||
| from core.plugin.utils.converter import convert_parameters_to_plugin_format | |||
| class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__( | |||
| self, | |||
| entity: DatasourceEntity, | |||
| runtime: DatasourceRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| plugin_unique_identifier: str, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| def _get_website_crawl( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: GetWebsiteCrawlRequest, | |||
| provider_type: str, | |||
| ) -> GetWebsiteCrawlResponse: | |||
| manager = PluginDatasourceManager() | |||
| datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) | |||
| return manager.get_website_crawl( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| datasource_parameters=datasource_parameters, | |||
| provider_type=provider_type, | |||
| ) | |||
| def datasource_provider_type(self) -> DatasourceProviderType: | |||
| return DatasourceProviderType.WEBSITE_CRAWL | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return WebsiteCrawlDatasourcePlugin( | |||
| entity=self.entity, | |||
| runtime=runtime, | |||
| tenant_id=self.tenant_id, | |||
| icon=self.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| @@ -0,0 +1,50 @@ | |||
| from core.datasource.__base.datasource_plugin import DatasourcePlugin | |||
| from core.datasource.__base.datasource_provider import DatasourcePluginProviderController | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType | |||
| from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin | |||
| class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): | |||
| entity: DatasourceProviderEntityWithPlugin | |||
| tenant_id: str | |||
| plugin_id: str | |||
| plugin_unique_identifier: str | |||
| def __init__( | |||
| self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str | |||
| ) -> None: | |||
| super().__init__(entity) | |||
| self.tenant_id = tenant_id | |||
| self.plugin_id = plugin_id | |||
| self.plugin_unique_identifier = plugin_unique_identifier | |||
| @property | |||
| def provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| returns the type of the provider | |||
| """ | |||
| return DatasourceProviderType.WEBSITE_CRAWL | |||
| def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore | |||
| """ | |||
| return datasource with given name | |||
| """ | |||
| datasource_entity = next( | |||
| ( | |||
| datasource_entity | |||
| for datasource_entity in self.entity.datasources | |||
| if datasource_entity.identity.name == datasource_name | |||
| ), | |||
| None, | |||
| ) | |||
| if not datasource_entity: | |||
| raise ValueError(f"Datasource with name {datasource_name} not found") | |||
| return WebsiteCrawlDatasourcePlugin( | |||
| entity=datasource_entity, | |||
| runtime=DatasourceRuntime(tenant_id=self.tenant_id), | |||
| tenant_id=self.tenant_id, | |||
| icon=self.entity.identity.icon, | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
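Putting the website crawl pieces together, a hedged sketch of the controller-to-plugin flow; the crawl parameters are placeholders, since their exact shape is plugin-specific.

```python
# Sketch only: `controller` is a WebsiteCrawlDatasourcePluginProviderController built from a
# loaded provider entity; the request/enum types come from core.datasource.entities.datasource_entities.

def crawl_site(controller, datasource_name: str, user_id: str, url: str):
    plugin = controller.get_datasource(datasource_name)
    request = GetWebsiteCrawlRequest(url=url, crawl_parameters={})  # crawl_parameters shape is plugin-specific
    response = plugin._get_website_crawl(
        user_id=user_id,
        datasource_parameters=request,
        provider_type=DatasourceProviderType.WEBSITE_CRAWL.value,  # assumed daemon route segment
    )
    for site in response.result:  # list[WebSiteInfo]
        print(site.source_url, site.title)
```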
| @@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel): | |||
| provider: str | |||
| plugin_unique_identifier: str | |||
| plugin_id: str | |||
| author: str | |||
| declaration: DatasourceProviderEntityWithPlugin | |||
| @@ -1,6 +1,14 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any | |||
| from core.datasource.entities.api_entities import DatasourceProviderApiEntity | |||
| from core.datasource.entities.datasource_entities import ( | |||
| GetOnlineDocumentPageContentRequest, | |||
| GetOnlineDocumentPageContentResponse, | |||
| GetOnlineDocumentPagesRequest, | |||
| GetOnlineDocumentPagesResponse, | |||
| GetWebsiteCrawlRequest, | |||
| GetWebsiteCrawlResponse, | |||
| ) | |||
| from core.plugin.entities.plugin import GenericProviderID, ToolProviderID | |||
| from core.plugin.entities.plugin_daemon import ( | |||
| PluginBasicBooleanResponse, | |||
| @@ -10,7 +18,7 @@ from core.plugin.impl.base import BasePluginClient | |||
| class PluginDatasourceManager(BasePluginClient): | |||
| def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: | |||
| def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: | |||
| """ | |||
| Fetch datasource providers for the given tenant. | |||
| """ | |||
| @@ -19,27 +27,27 @@ class PluginDatasourceManager(BasePluginClient): | |||
| for provider in json_response.get("data", []): | |||
| declaration = provider.get("declaration", {}) or {} | |||
| provider_name = declaration.get("identity", {}).get("name") | |||
| for tool in declaration.get("tools", []): | |||
| tool["identity"]["provider"] = provider_name | |||
| for datasource in declaration.get("datasources", []): | |||
| datasource["identity"]["provider"] = provider_name | |||
| return json_response | |||
| response = self._request_with_plugin_daemon_response( | |||
| "GET", | |||
| f"plugin/{tenant_id}/management/datasources", | |||
| list[PluginDatasourceProviderEntity], | |||
| params={"page": 1, "page_size": 256}, | |||
| transformer=transformer, | |||
| ) | |||
| # response = self._request_with_plugin_daemon_response( | |||
| # "GET", | |||
| # f"plugin/{tenant_id}/management/datasources", | |||
| # list[PluginDatasourceProviderEntity], | |||
| # params={"page": 1, "page_size": 256}, | |||
| # transformer=transformer, | |||
| # ) | |||
| for provider in response: | |||
| provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" | |||
| # for provider in response: | |||
| # provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" | |||
| # override the provider name for each tool to plugin_id/provider_name | |||
| for datasource in provider.declaration.datasources: | |||
| datasource.identity.provider = provider.declaration.identity.name | |||
| # # override the provider name for each tool to plugin_id/provider_name | |||
| # for datasource in provider.declaration.datasources: | |||
| # datasource.identity.provider = provider.declaration.identity.name | |||
| return response | |||
| return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] | |||
| def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: | |||
| """ | |||
| @@ -71,15 +79,16 @@ class PluginDatasourceManager(BasePluginClient): | |||
| return response | |||
| def invoke_first_step( | |||
| def get_website_crawl( | |||
| self, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| datasource_provider: str, | |||
| datasource_name: str, | |||
| credentials: dict[str, Any], | |||
| datasource_parameters: dict[str, Any], | |||
| ) -> Mapping[str, Any]: | |||
| datasource_parameters: GetWebsiteCrawlRequest, | |||
| provider_type: str, | |||
| ) -> GetWebsiteCrawlResponse: | |||
| """ | |||
| Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. | |||
| """ | |||
| @@ -88,8 +97,8 @@ class PluginDatasourceManager(BasePluginClient): | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/datasource/first_step", | |||
| dict, | |||
| f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl", | |||
| GetWebsiteCrawlResponse, | |||
| data={ | |||
| "user_id": user_id, | |||
| "data": { | |||
| @@ -109,15 +118,16 @@ class PluginDatasourceManager(BasePluginClient): | |||
| raise Exception("No response from plugin daemon") | |||
| def invoke_second_step( | |||
| def get_online_document_pages( | |||
| self, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| datasource_provider: str, | |||
| datasource_name: str, | |||
| credentials: dict[str, Any], | |||
| datasource_parameters: dict[str, Any], | |||
| ) -> Mapping[str, Any]: | |||
| datasource_parameters: GetOnlineDocumentPagesRequest, | |||
| provider_type: str, | |||
| ) -> GetOnlineDocumentPagesResponse: | |||
| """ | |||
| Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. | |||
| """ | |||
| @@ -126,8 +136,47 @@ class PluginDatasourceManager(BasePluginClient): | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/datasource/second_step", | |||
| dict, | |||
| f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages", | |||
| GetOnlineDocumentPagesResponse, | |||
| data={ | |||
| "user_id": user_id, | |||
| "data": { | |||
| "provider": datasource_provider_id.provider_name, | |||
| "datasource": datasource_name, | |||
| "credentials": credentials, | |||
| "datasource_parameters": datasource_parameters, | |||
| }, | |||
| }, | |||
| headers={ | |||
| "X-Plugin-ID": datasource_provider_id.plugin_id, | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise Exception("No response from plugin daemon") | |||
| def get_online_document_page_content( | |||
| self, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| datasource_provider: str, | |||
| datasource_name: str, | |||
| credentials: dict[str, Any], | |||
| datasource_parameters: GetOnlineDocumentPageContentRequest, | |||
| provider_type: str, | |||
| ) -> GetOnlineDocumentPageContentResponse: | |||
| """ | |||
| Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. | |||
| """ | |||
| datasource_provider_id = GenericProviderID(datasource_provider) | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content", | |||
| GetOnlineDocumentPageContentResponse, | |||
| data={ | |||
| "user_id": user_id, | |||
| "data": { | |||
| @@ -176,3 +225,53 @@ class PluginDatasourceManager(BasePluginClient): | |||
| return resp.result | |||
| return False | |||
| def _get_local_file_datasource_provider(self) -> dict[str, Any]: | |||
| return { | |||
| "id": "langgenius/file/file", | |||
| "author": "langgenius", | |||
| "name": "langgenius/file/file", | |||
| "plugin_id": "langgenius/file", | |||
| "plugin_unique_identifier": "langgenius/file:0.0.1@dify", | |||
| "description": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| }, | |||
| "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", | |||
| "label": { | |||
| "zh_Hans": "File", | |||
| "en_US": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| }, | |||
| "type": "datasource", | |||
| "team_credentials": {}, | |||
| "is_team_authorization": False, | |||
| "allow_delete": True, | |||
| "datasources": [{ | |||
| "author": "langgenius", | |||
| "name": "upload_file", | |||
| "label": { | |||
| "en_US": "File", | |||
| "zh_Hans": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File" | |||
| }, | |||
| "description": { | |||
| "en_US": "File", | |||
| "zh_Hans": "File", | |||
| "pt_BR": "File", | |||
| "ja_JP": "File." | |||
| }, | |||
| "parameters": [], | |||
| "labels": [ | |||
| "search" | |||
| ], | |||
| "output_schema": None | |||
| }], | |||
| "labels": [ | |||
| "search" | |||
| ] | |||
| } | |||
| @@ -14,3 +14,7 @@ class SystemVariableKey(StrEnum): | |||
| APP_ID = "app_id" | |||
| WORKFLOW_ID = "workflow_id" | |||
| WORKFLOW_RUN_ID = "workflow_run_id" | |||
| # RAG Pipeline | |||
| DOCUMENT_ID = "document_id" | |||
| BATCH = "batch" | |||
| DATASET_ID = "dataset_id" | |||
| @@ -3,7 +3,11 @@ from typing import Any, cast | |||
| from core.datasource.entities.datasource_entities import ( | |||
| DatasourceParameter, | |||
| DatasourceProviderType, | |||
| GetWebsiteCrawlResponse, | |||
| ) | |||
| from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin | |||
| from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin | |||
| from core.file import File | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.variables.segments import ArrayAnySegment | |||
| @@ -77,15 +81,44 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): | |||
| for_log=True, | |||
| ) | |||
| # get conversation id | |||
| conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) | |||
| try: | |||
| # TODO: handle result | |||
| result = datasource_runtime._invoke_second_step( | |||
| user_id=self.user_id, | |||
| datasource_parameters=parameters, | |||
| ) | |||
| if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: | |||
| datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) | |||
| result = datasource_runtime._get_online_document_page_content( | |||
| user_id=self.user_id, | |||
| datasource_parameters=parameters, | |||
| provider_type=node_data.provider_type, | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| outputs={ | |||
| "result": result.result.model_dump(), | |||
| "datasource_type": datasource_runtime.datasource_provider_type, | |||
| }, | |||
| ) | |||
| elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL: | |||
| datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) | |||
| result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( | |||
| user_id=self.user_id, | |||
| datasource_parameters=parameters, | |||
| provider_type=node_data.provider_type, | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=parameters_for_log, | |||
| metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, | |||
| outputs={ | |||
| "result": result.result.model_dump(), | |||
| "datasource_type": datasource_runtime.datasource_provider_type, | |||
| }, | |||
| ) | |||
| else: | |||
| raise DatasourceNodeError( | |||
| f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}" | |||
| ) | |||
| except PluginDaemonClientSideError as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| @@ -155,9 +155,4 @@ class KnowledgeIndexNodeData(BaseNodeData): | |||
| """ | |||
| type: str = "knowledge-index" | |||
| dataset_id: str | |||
| document_id: str | |||
| index_chunk_variable_selector: list[str] | |||
| chunk_structure: Literal["general", "parent-child"] | |||
| index_method: IndexMethod | |||
| retrieval_setting: RetrievalSetting | |||
| @@ -1,25 +1,19 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from flask_login import current_user | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables.segments import ObjectSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.llm.node import LLMNode | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, Document, RateLimitLog | |||
| from models.dataset import Dataset, Document | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| from services.feature_service import FeatureService | |||
| from .entities import KnowledgeIndexNodeData | |||
| from .exc import ( | |||
| @@ -43,8 +37,9 @@ class KnowledgeIndexNode(LLMNode): | |||
| def _run(self) -> NodeRunResult: # type: ignore | |||
| node_data = cast(KnowledgeIndexNodeData, self.node_data) | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| # extract variables | |||
| variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector) | |||
| variable = variable_pool.get(node_data.index_chunk_variable_selector) | |||
| if not isinstance(variable, ObjectSegment): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| @@ -57,34 +52,9 @@ class KnowledgeIndexNode(LLMNode): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." | |||
| ) | |||
| # check rate limit | |||
| if self.tenant_id: | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| current_time = int(time.time() * 1000) | |||
| key = f"rate_limit_{self.tenant_id}" | |||
| redis_client.zadd(key, {current_time: current_time}) | |||
| redis_client.zremrangebyscore(key, 0, current_time - 60000) | |||
| request_count = redis_client.zcard(key) | |||
| if request_count > knowledge_rate_limit.limit: | |||
| # add ratelimit record | |||
| rate_limit_log = RateLimitLog( | |||
| tenant_id=self.tenant_id, | |||
| subscription_plan=knowledge_rate_limit.subscription_plan, | |||
| operation="knowledge", | |||
| ) | |||
| db.session.add(rate_limit_log) | |||
| db.session.commit() | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=variables, | |||
| error="Sorry, you have reached the knowledge base request rate limit of your subscription.", | |||
| error_type="RateLimitExceeded", | |||
| ) | |||
| # retrieve knowledge | |||
| try: | |||
| results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks) | |||
| results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) | |||
| outputs = {"result": results} | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs | |||
| @@ -107,54 +77,26 @@ class KnowledgeIndexNode(LLMNode): | |||
| error_type=type(e).__name__, | |||
| ) | |||
| def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any: | |||
| dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() | |||
| def _invoke_knowledge_index( | |||
| self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool | |||
| ) -> Any: | |||
| dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) | |||
| if not dataset_id: | |||
| raise KnowledgeIndexNodeError("Dataset ID is required.") | |||
| document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) | |||
| if not document_id: | |||
| raise KnowledgeIndexNodeError("Document ID is required.") | |||
| batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) | |||
| if not batch: | |||
| raise KnowledgeIndexNodeError("Batch is required.") | |||
| dataset = Dataset.query.filter_by(id=dataset_id).first() | |||
| if not dataset: | |||
| raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") | |||
| raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") | |||
| document = Document.query.filter_by(id=node_data.document_id).first() | |||
| document = Document.query.filter_by(id=document_id).first() | |||
| if not document: | |||
| raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") | |||
| retrieval_setting = node_data.retrieval_setting | |||
| index_method = node_data.index_method | |||
| if not dataset.indexing_technique: | |||
| if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: | |||
| raise ValueError("Indexing technique is invalid") | |||
| dataset.indexing_technique = index_method.indexing_technique | |||
| if index_method.indexing_technique == "high_quality": | |||
| model_manager = ModelManager() | |||
| if ( | |||
| index_method.embedding_setting.embedding_model | |||
| and index_method.embedding_setting.embedding_model_provider | |||
| ): | |||
| dataset_embedding_model = index_method.embedding_setting.embedding_model | |||
| dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider | |||
| else: | |||
| embedding_model = model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| dataset_embedding_model = embedding_model.model | |||
| dataset_embedding_model_provider = embedding_model.provider | |||
| dataset.embedding_model = dataset_embedding_model | |||
| dataset.embedding_model_provider = dataset_embedding_model_provider | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| dataset_embedding_model_provider, dataset_embedding_model | |||
| ) | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| if not dataset.retrieval_model: | |||
| default_retrieval_model = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| raise KnowledgeIndexNodeError(f"Document {document_id} not found.") | |||
| dataset.retrieval_model = ( | |||
| retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model | |||
| ) # type: ignore | |||
| index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() | |||
| index_processor.index(dataset, document, chunks) | |||
| @@ -166,6 +108,7 @@ class KnowledgeIndexNode(LLMNode): | |||
| return { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "batch": batch, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "created_at": document.created_at, | |||
| @@ -1,66 +0,0 @@ | |||
| METADATA_FILTER_SYSTEM_PROMPT = """ | |||
| ### Job Description', | |||
| You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value | |||
| ### Task | |||
| Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| """ # noqa: E501 | |||
| METADATA_FILTER_USER_PROMPT_1 = """ | |||
| { "input_text": "I want to know which company’s email address test@example.com is?", | |||
| "metadata_fields": ["filename", "email", "phone", "address"] | |||
| } | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_1 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} | |||
| ] | |||
| } | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_2 = """ | |||
| {"input_text": "What are the movies with a score of more than 9 in 2024?", | |||
| "metadata_fields": ["name", "year", "rating", "country"]} | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_2 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, | |||
| {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, | |||
| ]} | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_3 = """ | |||
| '{{"input_text": "{input_text}",', | |||
| '"metadata_fields": {metadata_fields}}}' | |||
| """ | |||
| METADATA_FILTER_COMPLETION_PROMPT = """ | |||
| ### Job Description | |||
| You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value | |||
| ### Task | |||
| # Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| ### Example | |||
| Here is the chat example between human and assistant, inside <example></example> XML tags. | |||
| <example> | |||
| User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} | |||
| User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} | |||
| </example> | |||
| ### User Input | |||
| {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} | |||
| ### Assistant Output | |||
| """ # noqa: E501 | |||
| @@ -57,8 +57,6 @@ class MultipleRetrievalConfig(BaseModel): | |||
| class ModelConfig(BaseModel): | |||
| provider: str | |||
| name: str | |||
| mode: str | |||
| @@ -39,7 +39,6 @@ from core.variables.variables import ( | |||
| from core.workflow.constants import ( | |||
| CONVERSATION_VARIABLE_NODE_ID, | |||
| ENVIRONMENT_VARIABLE_NODE_ID, | |||
| PIPELINE_VARIABLE_NODE_ID, | |||
| ) | |||
| @@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen | |||
| result = result.model_copy(update={"selector": selector}) | |||
| return cast(Variable, result) | |||
| def build_segment(value: Any, /) -> Segment: | |||
| if value is None: | |||
| return NoneSegment() | |||
| @@ -0,0 +1,113 @@ | |||
| """add_pipeline_info_2 | |||
| Revision ID: abb18a379e62 | |||
| Revises: b35c3db83d09 | |||
| Create Date: 2025-05-16 16:59:16.423127 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'abb18a379e62' | |||
| down_revision = 'b35c3db83d09' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.drop_table('component_failure_stats') | |||
| op.drop_table('reliability_data') | |||
| op.drop_table('maintenance') | |||
| op.drop_table('operational_data') | |||
| op.drop_table('component_failure') | |||
| op.drop_table('tool_providers') | |||
| op.drop_table('safety_data') | |||
| op.drop_table('incident_data') | |||
| with op.batch_alter_table('pipelines', schema=None) as batch_op: | |||
| batch_op.drop_column('mode') | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('pipelines', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False)) | |||
| op.create_table('incident_data', | |||
| sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False), | |||
| sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False), | |||
| sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey') | |||
| ) | |||
| op.create_table('safety_data', | |||
| sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False), | |||
| sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey') | |||
| ) | |||
| op.create_table('tool_providers', | |||
| sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), | |||
| sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), | |||
| sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), | |||
| sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), | |||
| sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), | |||
| sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), | |||
| sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') | |||
| ) | |||
| op.create_table('component_failure', | |||
| sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False), | |||
| sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'), | |||
| sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry') | |||
| ) | |||
| op.create_table('operational_data', | |||
| sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False), | |||
| sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), | |||
| sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False), | |||
| sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey') | |||
| ) | |||
| op.create_table('maintenance', | |||
| sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False), | |||
| sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey') | |||
| ) | |||
| op.create_table('reliability_data', | |||
| sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), | |||
| sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey') | |||
| ) | |||
| op.create_table('component_failure_stats', | |||
| sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False), | |||
| sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), | |||
| sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True), | |||
| sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), | |||
| sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), | |||
| sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey') | |||
| ) | |||
| # ### end Alembic commands ### | |||
| @@ -1170,6 +1170,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] | |||
| def pipeline(self): | |||
| return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() | |||
| class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] | |||
| __tablename__ = "pipeline_customized_templates" | |||
| __table_args__ = ( | |||
| @@ -1205,6 +1206,7 @@ class Pipeline(Base): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def dataset(self): | |||
| return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() | |||
| @@ -52,6 +52,7 @@ class AppMode(StrEnum): | |||
| ADVANCED_CHAT = "advanced-chat" | |||
| AGENT_CHAT = "agent-chat" | |||
| CHANNEL = "channel" | |||
| RAG_PIPELINE = "rag-pipeline" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "AppMode": | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import TYPE_CHECKING, Any, List, Optional, Self, Union | |||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | |||
| from uuid import uuid4 | |||
| from core.variables import utils as variable_utils | |||
| @@ -43,7 +43,7 @@ class WorkflowType(Enum): | |||
| WORKFLOW = "workflow" | |||
| CHAT = "chat" | |||
| RAG_PIPELINE = "rag_pipeline" | |||
| RAG_PIPELINE = "rag-pipeline" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "WorkflowType": | |||
| @@ -370,7 +370,7 @@ class Workflow(Base): | |||
| return results | |||
| @rag_pipeline_variables.setter | |||
| def rag_pipeline_variables(self, values: List[dict]) -> None: | |||
| def rag_pipeline_variables(self, values: list[dict]) -> None: | |||
| self._rag_pipeline_variables = json.dumps( | |||
| {item["variable"]: item for item in values}, | |||
| ensure_ascii=False, | |||
| @@ -1550,7 +1550,7 @@ class DocumentService: | |||
| @staticmethod | |||
| def build_document( | |||
| dataset: Dataset, | |||
| process_rule_id: str, | |||
| process_rule_id: str | None, | |||
| data_source_type: str, | |||
| document_form: str, | |||
| document_language: str, | |||
| @@ -0,0 +1,109 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Union | |||
| from configs import dify_config | |||
| from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator | |||
| from core.app.apps.pipeline.pipeline_generator import PipelineGenerator | |||
| from core.app.apps.workflow.app_generator import WorkflowAppGenerator | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from models.dataset import Pipeline | |||
| from models.model import Account, App, AppMode, EndUser | |||
| from models.workflow import Workflow | |||
| from services.rag_pipeline.rag_pipeline import RagPipelineService | |||
| class PipelineGenerateService: | |||
| @classmethod | |||
| def generate( | |||
| cls, | |||
| pipeline: Pipeline, | |||
| user: Union[Account, EndUser], | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| """ | |||
| Pipeline Content Generate | |||
| :param pipeline: pipeline | |||
| :param user: user | |||
| :param args: args | |||
| :param invoke_from: invoke from | |||
| :param streaming: streaming | |||
| :return: | |||
| """ | |||
| try: | |||
| workflow = cls._get_workflow(pipeline, invoke_from) | |||
| return PipelineGenerator.convert_to_event_stream( | |||
| PipelineGenerator().generate( | |||
| pipeline=pipeline, | |||
| workflow=workflow, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| streaming=streaming, | |||
| call_depth=0, | |||
| workflow_thread_pool_id=None, | |||
| ), | |||
| ) | |||
| except Exception: | |||
| raise | |||
| @staticmethod | |||
| def _get_max_active_requests(app_model: App) -> int: | |||
| max_active_requests = app_model.max_active_requests | |||
| if max_active_requests is None: | |||
| max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) | |||
| return max_active_requests | |||
| @classmethod | |||
| def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): | |||
| if app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) | |||
| return AdvancedChatAppGenerator.convert_to_event_stream( | |||
| AdvancedChatAppGenerator().single_iteration_generate( | |||
| app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming | |||
| ) | |||
| ) | |||
| elif app_model.mode == AppMode.WORKFLOW.value: | |||
| workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) | |||
| return AdvancedChatAppGenerator.convert_to_event_stream( | |||
| WorkflowAppGenerator().single_iteration_generate( | |||
| app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming | |||
| ) | |||
| ) | |||
| else: | |||
| raise ValueError(f"Invalid app mode {app_model.mode}") | |||
| @classmethod | |||
| def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): | |||
| workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) | |||
| return WorkflowAppGenerator.convert_to_event_stream( | |||
| WorkflowAppGenerator().single_loop_generate( | |||
| app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming | |||
| ) | |||
| ) | |||
| @classmethod | |||
| def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow: | |||
| """ | |||
| Get workflow | |||
| :param pipeline: pipeline | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| rag_pipeline_service = RagPipelineService() | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| # fetch draft workflow by app_model | |||
| workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| else: | |||
| # fetch published workflow by app_model | |||
| workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) | |||
| if not workflow: | |||
| raise ValueError("Workflow not published") | |||
| return workflow | |||
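A hedged usage sketch of the new service: InvokeFrom.DEBUGGER routes through the draft workflow, anything else through the published one, exactly as _get_workflow above decides.

```python
# Sketch only: `pipeline` and `account` are assumed to be loaded by the caller,
# and `args` is the already-validated request payload.
from core.app.entities.app_invoke_entities import InvokeFrom


def debug_run_pipeline(pipeline, account, args):
    # Raises ValueError("Workflow not initialized") if no draft workflow exists yet.
    return PipelineGenerateService.generate(
        pipeline=pipeline,
        user=account,
        args=args,
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,
    )
```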
| @@ -29,32 +29,31 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| :param language: language | |||
| :return: | |||
| """ | |||
| pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter( | |||
| PipelineBuiltInTemplate.language == language | |||
| ).all() | |||
| pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( | |||
| db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() | |||
| ) | |||
| recommended_pipelines_results = [] | |||
| for pipeline_built_in_template in pipeline_built_in_templates: | |||
| pipeline_model: Pipeline = pipeline_built_in_template.pipeline | |||
| recommended_pipeline_result = { | |||
| 'id': pipeline_built_in_template.id, | |||
| 'name': pipeline_built_in_template.name, | |||
| 'pipeline_id': pipeline_model.id, | |||
| 'description': pipeline_built_in_template.description, | |||
| 'icon': pipeline_built_in_template.icon, | |||
| 'copyright': pipeline_built_in_template.copyright, | |||
| 'privacy_policy': pipeline_built_in_template.privacy_policy, | |||
| 'position': pipeline_built_in_template.position, | |||
| "id": pipeline_built_in_template.id, | |||
| "name": pipeline_built_in_template.name, | |||
| "pipeline_id": pipeline_model.id, | |||
| "description": pipeline_built_in_template.description, | |||
| "icon": pipeline_built_in_template.icon, | |||
| "copyright": pipeline_built_in_template.copyright, | |||
| "privacy_policy": pipeline_built_in_template.privacy_policy, | |||
| "position": pipeline_built_in_template.position, | |||
| } | |||
| dataset: Dataset = pipeline_model.dataset | |||
| if dataset: | |||
| recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure | |||
| recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure | |||
| recommended_pipelines_results.append(recommended_pipeline_result) | |||
| return {'pipeline_templates': recommended_pipelines_results} | |||
| return {"pipeline_templates": recommended_pipelines_results} | |||
| @classmethod | |||
| def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: | |||
| @@ -64,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): | |||
| :return: | |||
| """ | |||
| from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService | |||
| # is in public recommended list | |||
| pipeline_template = ( | |||
| db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() | |||
| @@ -3,7 +3,7 @@ import threading | |||
| import time | |||
| from collections.abc import Callable, Generator, Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import Any, Literal, Optional | |||
| from typing import Any, Optional | |||
| from uuid import uuid4 | |||
| from flask_login import current_user | |||
| @@ -46,7 +46,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi | |||
| class RagPipelineService: | |||
| @staticmethod | |||
| def get_pipeline_templates( | |||
| type: Literal["built-in", "customized"] = "built-in", language: str = "en-US" | |||
| type: str = "built-in", language: str = "en-US" | |||
| ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: | |||
| if type == "built-in": | |||
| mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE | |||
| @@ -358,11 +358,11 @@ class RagPipelineService: | |||
| return workflow_node_execution | |||
| def run_datasource_workflow_node( | |||
| def run_published_workflow_node( | |||
| self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Run published workflow datasource | |||
| Run published workflow node | |||
| """ | |||
| # fetch published workflow by app_model | |||
| published_workflow = self.get_published_workflow(pipeline=pipeline) | |||
| @@ -393,6 +393,41 @@ class RagPipelineService: | |||
| return workflow_node_execution | |||
| def run_datasource_workflow_node( | |||
| self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account | |||
| ) -> dict: | |||
| """ | |||
| Run published workflow datasource | |||
| """ | |||
| # fetch published workflow by app_model | |||
| published_workflow = self.get_published_workflow(pipeline=pipeline) | |||
| if not published_workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| # run datasource node against the published workflow | |||
| start_at = time.perf_counter() | |||
| nodes = published_workflow.graph_dict.get("nodes", []) | |||
| datasource_node_data = next((node.get("data", {}) for node in nodes if node.get("id") == node_id), {}) | |||
| if not datasource_node_data: | |||
| raise ValueError("Datasource node data not found") | |||
| from core.datasource.datasource_manager import DatasourceManager | |||
| datasource_runtime = DatasourceManager.get_datasource_runtime( | |||
| provider_id=datasource_node_data.get("provider_id"), | |||
| datasource_name=datasource_node_data.get("datasource_name"), | |||
| tenant_id=pipeline.tenant_id, | |||
| ) | |||
| result = datasource_runtime._invoke_first_step( | |||
| inputs=user_inputs, | |||
| provider_type=datasource_node_data.get("provider_type"), | |||
| user_id=account.id, | |||
| ) | |||
| return { | |||
| "result": result, | |||
| "provider_type": datasource_node_data.get("provider_type"), | |||
| } | |||
| def run_free_workflow_node( | |||
| self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] | |||
| ) -> WorkflowNodeExecution: | |||
| @@ -552,7 +587,7 @@ class RagPipelineService: | |||
| return workflow | |||
| def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: | |||
| def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: | |||
| """ | |||
| Get second step parameters of rag pipeline | |||
| """ | |||
| @@ -567,9 +602,33 @@ class RagPipelineService: | |||
| return {} | |||
| # get datasource provider | |||
| datasource_provider_variables = [item for item in rag_pipeline_variables | |||
| if item.get("belong_to_node_id") == node_id | |||
| or item.get("belong_to_node_id") == "shared"] | |||
| datasource_provider_variables = [ | |||
| item | |||
| for item in rag_pipeline_variables | |||
| if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" | |||
| ] | |||
| return datasource_provider_variables | |||
| def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: | |||
| """ | |||
| Get second step parameters of rag pipeline | |||
| """ | |||
| workflow = self.get_draft_workflow(pipeline=pipeline) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| # get second step node | |||
| rag_pipeline_variables = workflow.rag_pipeline_variables | |||
| if not rag_pipeline_variables: | |||
| return {} | |||
| # get datasource provider | |||
| datasource_provider_variables = [ | |||
| item | |||
| for item in rag_pipeline_variables | |||
| if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" | |||
| ] | |||
| return datasource_provider_variables | |||
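To illustrate what the two second-step helpers return, a small worked example of the belong_to_node_id filter with invented variable entries:

```python
# Illustrative data only; the real entries come from workflow.rag_pipeline_variables.
rag_pipeline_variables = [
    {"variable": "url", "belong_to_node_id": "datasource_node_1"},
    {"variable": "depth", "belong_to_node_id": "shared"},
    {"variable": "api_key", "belong_to_node_id": "datasource_node_2"},
]

node_id = "datasource_node_1"
second_step_parameters = [
    item
    for item in rag_pipeline_variables
    if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
# Keeps "url" (bound to this node) and "depth" (shared); drops "api_key".
```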
| def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: | |||