Browse Source

fix(rag-pipeline-dsl): dsl import session error

tags/2.0.0-beta.1
Novice 1 month ago
parent
commit
68f4d4b97c

+ 8
- 4
api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py View File

@@ -1,5 +1,6 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restx import Resource, marshal, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

import services
@@ -10,6 +11,7 @@ from controllers.console.wraps import (
cloud_edition_billing_rate_limit_check,
setup_required,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
@@ -64,10 +66,12 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=args["yaml_content"],
)
try:
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,

+ 13
- 9
api/core/app/apps/pipeline/pipeline_generator.py View File

@@ -110,9 +110,11 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")

with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
@@ -360,9 +362,10 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)

dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")

# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
@@ -446,9 +449,10 @@ class PipelineGenerator(BaseAppGenerator):
if args.get("inputs") is None:
raise ValueError("inputs is required")

dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")

# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(

+ 5
- 5
api/models/dataset.py View File

@@ -15,7 +15,7 @@ from typing import Any, Optional, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column

from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@@ -1286,9 +1286,8 @@ class Pipeline(Base): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
def retrieve_dataset(self, session: Session):
return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()


class DocumentPipelineExecutionLog(Base):
@@ -1308,6 +1307,7 @@ class DocumentPipelineExecutionLog(Base):
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
@@ -1318,4 +1318,4 @@ class PipelineRecommendedPlugin(Base):
position = db.Column(db.Integer, nullable=False, default=0)
active = db.Column(db.Boolean, nullable=False, default=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+ 11
- 7
api/services/rag_pipeline/rag_pipeline.py View File

@@ -352,9 +352,10 @@ class RagPipelineService:
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
@@ -1110,9 +1111,10 @@ class RagPipelineService:
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")

# check template name is exist
template_name = args.get("name")
@@ -1136,7 +1138,9 @@ class RagPipelineService:

from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)

pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"),

+ 40
- 61
api/services/rag_pipeline/rag_pipeline_dsl_service.py View File

@@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
@@ -235,10 +234,7 @@ class RagPipelineDslService:
status=ImportStatus.FAILED,
error="Pipeline not found",
)
dataset = pipeline.dataset
if dataset:
self._session.merge(dataset)
dataset_name = dataset.name
dataset = pipeline.retrieve_dataset(session=self._session)

# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
@@ -300,7 +296,7 @@ class RagPipelineDslService:
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
if not dataset:
datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
names = [dataset.name for dataset in datasets]
generate_name = generate_incremental_name(names, name)
dataset = Dataset(
@@ -321,7 +317,7 @@ class RagPipelineDslService:
)
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
self._session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@@ -339,8 +335,8 @@ class RagPipelineDslService:
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
@@ -454,7 +450,7 @@ class RagPipelineDslService:
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
self._session.query(DatasetCollectionBinding)
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
@@ -472,8 +468,8 @@ class RagPipelineDslService:
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
@@ -538,18 +534,10 @@ class RagPipelineDslService:
account: Account,
dependencies: Optional[list[PluginDependency]] = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:
raise ValueError("Tenant id is required")
"""Create a new app or update an existing one."""
pipeline_data = data.get("rag_pipeline", {})
# Set icon type
icon_type_value = pipeline_data.get("icon_type")
if icon_type_value in ["emoji", "link"]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))

# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@@ -609,7 +597,7 @@ class RagPipelineDslService:
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
db.session.query(Workflow)
self._session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
@@ -632,8 +620,8 @@ class RagPipelineDslService:
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
db.session.add(workflow)
db.session.flush()
self._session.add(workflow)
self._session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
@@ -643,19 +631,18 @@ class RagPipelineDslService:
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
db.session.commit()
self._session.commit()

return pipeline

@classmethod
def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
"""
Export pipeline
:param pipeline: Pipeline instance
:param include_secret: Whether include secret variable
:return:
"""
dataset = pipeline.dataset
dataset = pipeline.retrieve_dataset(session=self._session)
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
@@ -672,12 +659,11 @@ class RagPipelineDslService:
},
}

cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)

return yaml.dump(export_data, allow_unicode=True) # type: ignore

@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
@@ -685,7 +671,7 @@ class RagPipelineDslService:
"""

workflow = (
db.session.query(Workflow)
self._session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
@@ -701,11 +687,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
@@ -713,19 +699,17 @@ class RagPipelineDslService:
)
]

@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies

@classmethod
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
@@ -882,25 +866,22 @@ class RagPipelineDslService:

return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()

@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()

@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
@@ -908,39 +889,37 @@ class RagPipelineDslService:
except Exception:
return None

@staticmethod
def create_rag_pipeline_dataset(
self,
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
self._session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
else:
# generate a random name as Untitled 1 2 3 ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
"Untitled",
)

with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": rag_pipeline_import_info.dataset_id,

Loading…
Cancel
Save