
r2

tags/2.0.0-beta.1
jyong, 5 months ago
Parent
Current commit
e7c48c0b69

api/controllers/console/datasets/rag_pipeline/rag_pipeline.py (+36, -8)

@@ -1,5 +1,6 @@
import logging

import yaml
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
@@ -12,10 +13,9 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import Pipeline, PipelineCustomizedTemplate
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

logger = logging.getLogger(__name__)

@@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource):
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return pipeline_template, 200
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200

@setup_required
@login_required
@@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource):
)
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")

dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200

class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200

api.add_resource(
PipelineTemplateListApi,
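
Note: the PATCH handler above now returns a bare 200 instead of the updated template, and the new POST handler publishes a pipeline as a customized template. A minimal client sketch of the publish call follows; the route and auth header are assumptions based on Dify console API conventions, not shown in this diff:

```python
import requests

BASE_URL = "http://localhost:5001/console/api"        # assumed deployment URL
HEADERS = {"Authorization": "Bearer <console-token>"}  # assumed auth scheme

# "name" is required and validated by _validate_name (1 to 40 characters);
# "description" defaults to "" and "icon_info" is an optional JSON object.
resp = requests.post(
    f"{BASE_URL}/rag/pipelines/<pipeline_id>/customized-templates",  # hypothetical route
    headers=HEADERS,
    json={
        "name": "web-ingestion",
        "description": "Publish the current pipeline as a reusable template",
        "icon_info": {"icon": "📙", "icon_type": "emoji"},
    },
)
assert resp.status_code == 200  # the handler returns a bare 200 on success
```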

api/core/app/apps/pipeline/pipeline_generator.py (+92, -62)

@@ -20,11 +20,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
@@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
@@ -54,7 +55,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None] | None: ...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

@overload
def generate(
@@ -101,23 +102,18 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline,
workflow=workflow,
)

# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4())
document_id = None

# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")

if invoke_from == InvokeFrom.PUBLISHED:
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
@@ -132,9 +128,15 @@ class PipelineGenerator(BaseAppGenerator):
document_form=dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
document_id = document.id
# init application generate entity
documents.append(document)
db.session.commit()

# run in child thread
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
@@ -159,7 +161,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow_run_id=workflow_run_id,
)

contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
@@ -183,6 +184,7 @@ class PipelineGenerator(BaseAppGenerator):
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline,
workflow=workflow,
user=user,
@@ -194,21 +196,47 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
# run in child thread
thread = threading.Thread(
target=self._generate,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"pipeline": pipeline,
"workflow": workflow,
"user": user,
"application_generate_entity": application_generate_entity,
"invoke_from": invoke_from,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"streaming": streaming,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)

thread.start()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [PipelineDocument(
id=document.id,
position=document.position,
data_source_info=document.data_source_info,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump() for document in documents
]
}
def _generate(
self,
*,
flask_app: Flask,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
@@ -232,40 +260,42 @@ class PipelineGenerator(BaseAppGenerator):
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
print(user.id)
with flask_app.app_context():
# init queue manager
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)

# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)

worker_thread.start()
worker_thread.start()

# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)

return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def single_iteration_generate(
self,
@@ -317,7 +347,6 @@ class PipelineGenerator(BaseAppGenerator):
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
@@ -338,6 +367,7 @@ class PipelineGenerator(BaseAppGenerator):
)

return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline,
workflow=workflow,
user=user,
@@ -399,7 +429,6 @@ class PipelineGenerator(BaseAppGenerator):
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

@@ -421,6 +450,7 @@ class PipelineGenerator(BaseAppGenerator):
)

return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline,
workflow=workflow,
user=user,
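
Note: for published runs the generator now pre-creates all Document rows, commits once, then runs each datasource in a fire-and-forget child thread, passing the concrete Flask app so the worker can re-enter an app context. A minimal sketch of that pattern with simplified names (`payload` and `dispatch` are illustrative, not from this commit):

```python
import threading

from flask import Flask, current_app


def _worker(flask_app: Flask, payload: dict) -> None:
    # Re-enter an application context in the child thread; without it,
    # Flask-bound globals such as db.session are unavailable here.
    with flask_app.app_context():
        ...  # run one pipeline invocation for this datasource


def dispatch(payloads: list[dict]) -> None:
    for payload in payloads:
        thread = threading.Thread(
            target=_worker,
            kwargs={
                # _get_current_object() unwraps the request-local proxy so the
                # real app object survives the hand-off to another thread.
                "flask_app": current_app._get_current_object(),  # type: ignore
                "payload": payload,
            },
        )
        thread.start()  # fire and forget: the caller returns immediately
```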

api/core/entities/knowledge_entities.py (+23, -0)

@@ -17,3 +17,26 @@ class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None


class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str

class PipelineDocument(BaseModel):
id: str
position: int
data_source_info: dict
name: str
indexing_status: str
error: str
enabled: bool



class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]
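
Note: the generator's non-streaming return value is assembled from these models via model_dump(). A minimal sketch with illustrative field values (the chunk_structure string and batch format are assumptions):

```python
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument

dataset = PipelineDataset(
    id="d-1", name="docs", description="demo dataset", chunk_structure="text_model"
)
document = PipelineDocument(
    id="doc-1",
    position=1,
    data_source_info={"url": "https://example.com"},
    name="page-1",
    indexing_status="waiting",
    error="",      # fields are non-optional, so callers must pass non-null values
    enabled=True,
)
response = {
    "batch": "20250101000000123456",  # strftime + random suffix, as in the generator
    "dataset": dataset.model_dump(),
    "documents": [document.model_dump()],
}
```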

api/core/repositories/sqlalchemy_workflow_node_execution_repository.py (+4, -2)

@@ -253,6 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
@@ -274,7 +275,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
WorkflowNodeExecution.triggered_from == triggered_from,
)

if self._app_id:
@@ -308,6 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[NodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
@@ -325,7 +327,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of NodeExecution instances
"""
# Get the database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)

# Convert database models to domain models
domain_models = []
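
Note: the new triggered_from keyword defaults to WORKFLOW_RUN, so existing callers keep their behavior; RAG pipeline callers opt in explicitly. A sketch, assuming `repository` and `run_id` are already in scope and that OrderConfig lives in the workflow repository module:

```python
from core.workflow.repository.workflow_node_execution_repository import OrderConfig  # assumed path
from models import WorkflowNodeExecutionTriggeredFrom

order_config = OrderConfig(order_by=["index"], order_direction="desc")

# Unchanged default: only WORKFLOW_RUN executions are returned.
executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)

# RAG pipeline history, as RagPipelineService now requests it:
executions = repository.get_by_workflow_run(
    workflow_run_id=run_id,
    order_config=order_config,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
```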

api/fields/dataset_fields.py (+1, -0)

@@ -87,6 +87,7 @@ dataset_detail_fields = {
"runtime_mode": fields.String,
"chunk_structure": fields.String,
"icon_info": fields.Nested(icon_info_fields),
"is_published": fields.Boolean,
}

dataset_query_detail_fields = {

api/models/dataset.py (+13, -5)

@@ -152,6 +152,8 @@ class Dataset(Base):

@property
def doc_form(self):
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
if document:
return document.doc_form
@@ -206,6 +208,13 @@ class Dataset(Base):
"external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
if pipeline:
return pipeline.is_published
return False

@property
def doc_metadata(self):
@@ -1154,10 +1163,11 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
pipeline_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False)
position = db.Column(db.Integer, nullable=False)
@@ -1166,9 +1176,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@property
def pipeline(self):
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()


class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
@@ -1180,11 +1187,12 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
pipeline_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False)
position = db.Column(db.Integer, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
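
Note: two behavioral changes fall out of this model code: doc_form now short-circuits on chunk_structure before querying Document rows, and is_published resolves through the bound Pipeline. A small usage sketch (repo imports assumed available):

```python
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).filter(Dataset.id == "<dataset-id>").first()
if dataset:
    form = dataset.doc_form           # chunk_structure wins; the Document lookup is the fallback
    published = dataset.is_published  # False when pipeline_id is unset or the pipeline row is missing
```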

api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py (+5, -5)

@@ -23,8 +23,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
result = self.fetch_pipeline_templates_from_builtin(language)
return result

def get_pipeline_template_detail(self, pipeline_id: str):
result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id)
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result

@classmethod
@@ -54,11 +54,11 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {})

@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
"""
Fetch pipeline template detail from builtin.
:param pipeline_id: Pipeline ID
:param template_id: Template ID
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(pipeline_id)
return builtin_data.get("pipeline_templates", {}).get(template_id)

api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py (+22, -14)

@@ -1,12 +1,13 @@
from typing import Optional

from flask_login import current_user
import yaml

from extensions.ext_database import db
from models.dataset import Pipeline, PipelineCustomizedTemplate
from services.app_dsl_service import AppDslService
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -35,13 +36,26 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
pipeline_templates = (
pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate)
.filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.all()
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:

recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
"icon": pipeline_customized_template.icon,
"position": pipeline_customized_template.position,
"chunk_structure": pipeline_customized_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)

return {"pipeline_templates": recommended_pipelines_results}

return {"pipeline_templates": pipeline_templates}

@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
@@ -57,15 +71,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template:
return None

# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None

return {
"id": pipeline.id,
"name": pipeline.name,
"icon": pipeline.icon,
"mode": pipeline.mode,
"export_data": AppDslService.export_dsl(app_model=pipeline),
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon": pipeline_template.icon,
"export_data": yaml.safe_load(pipeline_template.yaml_content),
}
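
Note: customized template detail is now served straight from the stored yaml_content instead of re-exporting the live Pipeline (the Pipeline and AppDslService lookups above were removed). The retrieval side is just a parse; the yaml_content below is illustrative:

```python
import yaml

yaml_content = (
    "version: 0.1.0\n"  # illustrative version string
    "kind: rag_pipeline\n"
    "workflow:\n"
    "  graph: {}\n"
)
export_data = yaml.safe_load(yaml_content)
assert export_data["kind"] == "rag_pipeline"
```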

api/services/rag_pipeline/pipeline_template/database/database_retrieval.py (+9, -24)

@@ -1,7 +1,9 @@
from typing import Optional

import yaml

from extensions.ext_database import db
from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
from models.dataset import PipelineBuiltInTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType

@@ -36,24 +38,18 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
if not pipeline_model:
continue

recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"pipeline_id": pipeline_model.id,
"description": pipeline_built_in_template.description,
"icon": pipeline_built_in_template.icon,
"copyright": pipeline_built_in_template.copyright,
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
"chunk_structure": pipeline_built_in_template.chunk_structure,
}
dataset: Dataset | None = pipeline_model.dataset
if dataset:
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)
recommended_pipelines_results.append(recommended_pipeline_result)

return {"pipeline_templates": recommended_pipelines_results}

@@ -64,8 +60,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param pipeline_id: Pipeline ID
:return:
"""
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
@@ -74,19 +68,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template:
return None

# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None

dataset: Dataset | None = pipeline.dataset
if not dataset:
return None

return {
"id": pipeline.id,
"name": pipeline.name,
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon": pipeline_template.icon,
"chunk_structure": dataset.chunk_structure,
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
"chunk_structure": pipeline_template.chunk_structure,
"export_data": yaml.safe_load(pipeline_template.yaml_content),
}

api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py (+2, -1)

@@ -1,4 +1,5 @@
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@@ -12,7 +13,7 @@ class PipelineTemplateRetrievalFactory:
case PipelineTemplateType.REMOTE:
return RemotePipelineTemplateRetrieval
case PipelineTemplateType.CUSTOMIZED:
return DatabasePipelineTemplateRetrieval
return CustomizedPipelineTemplateRetrieval
case PipelineTemplateType.DATABASE:
return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.BUILTIN:
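
Note: this one-line fix routes CUSTOMIZED to its own retrieval class; previously it silently fell through to the database-backed implementation. Mirroring the call site in RagPipelineService.get_pipeline_templates:

```python
from services.rag_pipeline.pipeline_template.pipeline_template_factory import (
    PipelineTemplateRetrievalFactory,
)

retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory("customized")()
result = retrieval_instance.get_pipeline_templates("en-US")
# -> {"pipeline_templates": [...]} scoped to the current tenant, per
#    CustomizedPipelineTemplateRetrieval above
```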

api/services/rag_pipeline/rag_pipeline.py (+34, -10)

@@ -7,7 +7,7 @@ from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import select
from sqlalchemy import or_, select
from sqlalchemy.orm import Session

import contexts
@@ -47,16 +47,19 @@ from models.workflow import (
WorkflowType,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
@staticmethod
@classmethod
def get_pipeline_templates(
type: str = "built-in", language: str = "en-US"
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -64,14 +67,14 @@ class RagPipelineService:
if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
return result
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language)
return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
return result

@classmethod
@classmethod
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
"""
Get pipeline template detail.
@@ -684,7 +687,10 @@ class RagPipelineService:
base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
)

if args.get("last_id"):
@@ -765,8 +771,26 @@ class RagPipelineService:

# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

return workflow_node_executions
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
db.session.commit()
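
Note: run history now matches both RAG pipeline trigger sources. An equivalent formulation with .in_(), which may read better if more sources are added later (an alternative sketch, not what this commit uses; tenant_id and pipeline_id assumed in scope, WorkflowRun import path assumed):

```python
from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun  # assumed import path

runs = (
    db.session.query(WorkflowRun)
    .filter(
        WorkflowRun.tenant_id == tenant_id,
        WorkflowRun.app_id == pipeline_id,
        WorkflowRun.triggered_from.in_(
            [
                WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ]
        ),
    )
    .all()
)
```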

api/services/rag_pipeline/rag_pipeline_dsl_service.py (+96, -69)

@@ -1,5 +1,7 @@
import base64
from datetime import UTC, datetime
import hashlib
import json
import logging
import uuid
from collections.abc import Mapping
@@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)

@@ -206,12 +207,12 @@ class RagPipelineDslService:
status = _check_version_compatibility(imported_version)

# Extract app data
pipeline_data = data.get("pipeline")
pipeline_data = data.get("rag_pipeline")
if not pipeline_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing pipeline data in YAML content",
error="Missing rag_pipeline data in YAML content",
)

# If app_id is provided, check if it exists
@@ -256,7 +257,7 @@ class RagPipelineDslService:
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]

# Create or update app
# Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
@@ -278,7 +279,9 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if not dataset:
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=name,
@@ -295,11 +298,6 @@ class RagPipelineDslService:
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
@@ -540,33 +538,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))

if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")

# Create new app
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = account.id
pipeline.updated_by = account.id

self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)

# Initialize pipeline based on mode
workflow_data = data.get("workflow")
@@ -583,12 +554,7 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])

rag_pipeline_service = RagPipelineService()
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None

graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -599,20 +565,78 @@ class RagPipelineDslService:
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id,
tenant_id=pipeline.tenant_id,
tenant_id=account.current_tenant_id,
)
)
]
rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=workflow_data.get("graph", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,

if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id

else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")

# Create new app
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = account.id
pipeline.updated_by = account.id

self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)

# create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
db.session.add(workflow)
db.session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
db.session.commit()


return pipeline

@classmethod
@@ -623,16 +647,19 @@ class RagPipelineDslService:
:param include_secret: Whether include secret variable
:return:
"""
dataset = pipeline.dataset
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "rag_pipeline",
"pipeline": {
"name": pipeline.name,
"mode": pipeline.mode,
"icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon,
"icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background,
"icon": icon_info.get("icon", "📙") if icon_info else "📙",
"icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
"icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
"description": pipeline.description,
"use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
},
}

@@ -647,8 +674,16 @@ class RagPipelineDslService:
:param export_data: export data
:param pipeline: Pipeline instance
"""
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)

workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")

@@ -855,14 +890,6 @@ class RagPipelineDslService:
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)

dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
@@ -870,11 +897,11 @@ class RagPipelineDslService:
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
dataset=None,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"dataset_id": rag_pipeline_import_info.dataset_id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
