@@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
@@ -235,10 +234,7 @@ class RagPipelineDslService:
                     status=ImportStatus.FAILED,
                     error="Pipeline not found",
                 )
-            dataset = pipeline.dataset
-            if dataset:
-                self._session.merge(dataset)
-                dataset_name = dataset.name
+            dataset = pipeline.retrieve_dataset(session=self._session)

             # If major version mismatch, store import info in Redis
             if status == ImportStatus.PENDING:
@@ -300,7 +296,7 @@ class RagPipelineDslService:
         ):
             raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
-            datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
             names = [dataset.name for dataset in datasets]
             generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
@@ -321,7 +317,7 @@ class RagPipelineDslService:
             )
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -339,8 +335,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -454,7 +450,7 @@ class RagPipelineDslService:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
+                    self._session.query(DatasetCollectionBinding)
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -472,8 +468,8 @@ class RagPipelineDslService:
                         collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                         type="dataset",
                     )
-                    db.session.add(dataset_collection_binding)
-                    db.session.commit()
+                    self._session.add(dataset_collection_binding)
+                    self._session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = knowledge_configuration.embedding_model
@@ -538,18 +534,10 @@ class RagPipelineDslService:
         account: Account,
         dependencies: Optional[list[PluginDependency]] = None,
     ) -> Pipeline:
-        """Create a new app or update an existing one."""
-        if not account.current_tenant_id:
-            raise ValueError("Tenant id is required")
+        """Create a new app or update an existing one."""
         pipeline_data = data.get("rag_pipeline", {})
-        # Set icon type
-        icon_type_value = pipeline_data.get("icon_type")
-        if icon_type_value in ["emoji", "link"]:
-            icon_type = icon_type_value
-        else:
-            icon_type = "emoji"
         icon = str(pipeline_data.get("icon", ""))

         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
         if not workflow_data or not isinstance(workflow_data, dict):
@@ -609,7 +597,7 @@ class RagPipelineDslService:
                 CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
             )
         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -632,8 +620,8 @@ class RagPipelineDslService:
                 conversation_variables=conversation_variables,
                 rag_pipeline_variables=rag_pipeline_variables_list,
             )
-            db.session.add(workflow)
-            db.session.flush()
+            self._session.add(workflow)
+            self._session.flush()
             pipeline.workflow_id = workflow.id
         else:
             workflow.graph = json.dumps(graph)
@@ -643,19 +631,18 @@ class RagPipelineDslService:
             workflow.conversation_variables = conversation_variables
             workflow.rag_pipeline_variables = rag_pipeline_variables_list
         # commit db session changes
-        db.session.commit()
+        self._session.commit()

         return pipeline

-    @classmethod
-    def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
+    def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
         """
         Export pipeline
         :param pipeline: Pipeline instance
         :param include_secret: Whether include secret variable
         :return:
         """
-        dataset = pipeline.dataset
+        dataset = pipeline.retrieve_dataset(session=self._session)
         if not dataset:
             raise ValueError("Missing dataset for rag pipeline")
         icon_info = dataset.icon_info
@@ -672,12 +659,11 @@ class RagPipelineDslService:
             },
         }

-        cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
+        self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)

         return yaml.dump(export_data, allow_unicode=True)  # type: ignore

-    @classmethod
-    def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
+    def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
         """
         Append workflow export data
         :param export_data: export data
@@ -685,7 +671,7 @@ class RagPipelineDslService:
         """

         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -701,11 +687,11 @@ class RagPipelineDslService:
             if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
                 dataset_ids = node["data"].get("dataset_ids", [])
                 node["data"]["dataset_ids"] = [
-                    cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
+                    self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
                     for dataset_id in dataset_ids
                 ]
         export_data["workflow"] = workflow_dict
-        dependencies = cls._extract_dependencies_from_workflow(workflow)
+        dependencies = self._extract_dependencies_from_workflow(workflow)
         export_data["dependencies"] = [
             jsonable_encoder(d.model_dump())
             for d in DependenciesAnalysisService.generate_dependencies(
@@ -713,19 +699,17 @@ class RagPipelineDslService:
             )
         ]

-    @classmethod
-    def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
+    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
         """
         Extract dependencies from workflow
         :param workflow: Workflow instance
         :return: dependencies list format like ["langgenius/google"]
         """
         graph = workflow.graph_dict
-        dependencies = cls._extract_dependencies_from_workflow_graph(graph)
+        dependencies = self._extract_dependencies_from_workflow_graph(graph)
         return dependencies

-    @classmethod
-    def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
+    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
         """
         Extract dependencies from workflow graph
         :param graph: Workflow graph
@@ -882,25 +866,22 @@ class RagPipelineDslService:

         return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

-    @staticmethod
-    def _generate_aes_key(tenant_id: str) -> bytes:
+    def _generate_aes_key(self, tenant_id: str) -> bytes:
         """Generate AES key based on tenant_id"""
         return hashlib.sha256(tenant_id.encode()).digest()

-    @classmethod
-    def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
+    def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
         """Encrypt dataset_id using AES-CBC mode"""
-        key = cls._generate_aes_key(tenant_id)
+        key = self._generate_aes_key(tenant_id)
         iv = key[:16]
         cipher = AES.new(key, AES.MODE_CBC, iv)
         ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
         return base64.b64encode(ct_bytes).decode()

-    @classmethod
-    def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
+    def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
         """AES decryption"""
         try:
-            key = cls._generate_aes_key(tenant_id)
+            key = self._generate_aes_key(tenant_id)
             iv = key[:16]
             cipher = AES.new(key, AES.MODE_CBC, iv)
             pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
@@ -908,39 +889,37 @@ class RagPipelineDslService:
         except Exception:
             return None

-    @staticmethod
     def create_rag_pipeline_dataset(
+        self,
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
-                db.session.query(Dataset)
+                self._session.query(Dataset)
                 .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
                 .first()
             ):
                 raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
         else:
             # generate a random name as Untitled 1 2 3 ...
-            datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
             names = [dataset.name for dataset in datasets]
             rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                 names,
                 "Untitled",
             )

-        with Session(db.engine) as session:
-            rag_pipeline_dsl_service = RagPipelineDslService(session)
-            account = cast(Account, current_user)
-            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=account,
-                import_mode=ImportMode.YAML_CONTENT.value,
-                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=None,
-                dataset_name=rag_pipeline_dataset_create_entity.name,
-                icon_info=rag_pipeline_dataset_create_entity.icon_info,
-            )
+        account = cast(Account, current_user)
+        rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
+            account=account,
+            import_mode=ImportMode.YAML_CONTENT.value,
+            yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+            dataset=None,
+            dataset_name=rag_pipeline_dataset_create_entity.name,
+            icon_info=rag_pipeline_dataset_create_entity.icon_info,
+        )
         return {
             "id": rag_pipeline_import_info.id,
             "dataset_id": rag_pipeline_import_info.dataset_id,