| @@ -1,3 +1,5 @@ | |||
| import base64 | |||
| import hashlib | |||
| import logging | |||
| import uuid | |||
| from collections.abc import Mapping | |||
| @@ -7,6 +9,8 @@ from urllib.parse import urlparse | |||
| from uuid import uuid4 | |||
| import yaml # type: ignore | |||
| from Crypto.Cipher import AES | |||
| from Crypto.Util.Padding import pad, unpad | |||
| from packaging import version | |||
| from pydantic import BaseModel, Field | |||
| from sqlalchemy import select | |||
| @@ -478,6 +482,15 @@ class AppDslService: | |||
| unique_hash = current_draft_workflow.unique_hash | |||
| else: | |||
| unique_hash = None | |||
| graph = workflow_data.get("graph", {}) | |||
| for node in graph.get("nodes", []): | |||
| if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: | |||
| dataset_ids = node["data"].get("dataset_ids", []) | |||
| node["data"]["dataset_ids"] = [ | |||
| decrypted_id | |||
| for dataset_id in dataset_ids | |||
| if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) | |||
| ] | |||
| workflow_service.sync_draft_workflow( | |||
| app_model=app, | |||
| graph=workflow_data.get("graph", {}), | |||
| @@ -552,7 +565,15 @@ class AppDslService: | |||
| if not workflow: | |||
| raise ValueError("Missing draft workflow configuration, please check.") | |||
| export_data["workflow"] = workflow.to_dict(include_secret=include_secret) | |||
| workflow_dict = workflow.to_dict(include_secret=include_secret) | |||
| for node in workflow_dict.get("graph", {}).get("nodes", []): | |||
| if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: | |||
| dataset_ids = node["data"].get("dataset_ids", []) | |||
| node["data"]["dataset_ids"] = [ | |||
| cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) | |||
| for dataset_id in dataset_ids | |||
| ] | |||
| export_data["workflow"] = workflow_dict | |||
| dependencies = cls._extract_dependencies_from_workflow(workflow) | |||
| export_data["dependencies"] = [ | |||
| jsonable_encoder(d.model_dump()) | |||
| @@ -724,3 +745,29 @@ class AppDslService: | |||
| return [] | |||
| return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) | |||
| @staticmethod | |||
| def _generate_aes_key(tenant_id: str) -> bytes: | |||
| """Generate AES key based on tenant_id""" | |||
| return hashlib.sha256(tenant_id.encode()).digest() | |||
| @classmethod | |||
| def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str: | |||
| """Encrypt dataset_id using AES-CBC mode""" | |||
| key = cls._generate_aes_key(tenant_id) | |||
| iv = key[:16] | |||
| cipher = AES.new(key, AES.MODE_CBC, iv) | |||
| ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size)) | |||
| return base64.b64encode(ct_bytes).decode() | |||
| @classmethod | |||
| def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None: | |||
| """AES decryption""" | |||
| try: | |||
| key = cls._generate_aes_key(tenant_id) | |||
| iv = key[:16] | |||
| cipher = AES.new(key, AES.MODE_CBC, iv) | |||
| pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size) | |||
| return pt.decode() | |||
| except Exception: | |||
| return None | |||