Bladeren bron

Fix/dsl kb encrypt (#17353)

tags/1.2.0
Dongyu Li 7 maanden geleden
bovenliggende
commit
2e9997110a
No account linked to committer's email address
1 gewijzigde bestanden met toevoegingen van 48 en 1 verwijderingen
  1. 48
    1
      api/services/app_dsl_service.py

+ 48
- 1
api/services/app_dsl_service.py Bestand weergeven

@@ -1,3 +1,5 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Mapping
@@ -7,6 +9,8 @@ from urllib.parse import urlparse
from uuid import uuid4

import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
@@ -478,6 +482,15 @@ class AppDslService:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
@@ -552,7 +565,15 @@ class AppDslService:
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")

export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
@@ -724,3 +745,29 @@ class AppDslService:
return []

return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()

@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()

@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None

Laden…
Annuleren
Opslaan