| @@ -1,11 +0,0 @@ | |||
| from tests.integration_tests.utils.parent_class import ParentClass | |||
| class ChildClass(ParentClass): | |||
| """Test child class for module import helper tests""" | |||
| def __init__(self, name): | |||
| super().__init__(name) | |||
| def get_name(self): | |||
| return f"Child: {self.name}" | |||
| @@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource): | |||
| ) | |||
| app_model.workflow_id = workflow.id | |||
| db.session.commit() | |||
| db.session.commit() # NOTE: this is necessary for update app_model.workflow_id | |||
| workflow_created_at = TimestampField().format(workflow.created_at) | |||
| @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from extensions.ext_database import db | |||
| @@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource): | |||
| workspace_id = notion_info["workspace_id"] | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page["page_id"], | |||
| @@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_database import db | |||
| @@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource): | |||
| if file_details: | |||
| for file_detail in file_details: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] | |||
| datasource_type=DatasourceType.FILE.value, | |||
| upload_file=file_detail, | |||
| document_model=args["doc_form"], | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif args["info_list"]["data_source_type"] == "notion_import": | |||
| @@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| workspace_id = notion_info["workspace_id"] | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page["page_id"], | |||
| @@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| website_info_list = args["info_list"]["website_info_list"] | |||
| for url in website_info_list["urls"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| "provider": website_info_list["provider"], | |||
| "job_id": website_info_list["job_id"], | |||
| @@ -40,6 +40,7 @@ from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from extensions.ext_database import db | |||
| from fields.document_fields import ( | |||
| @@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| raise NotFound("File not found.") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", upload_file=file, document_model=document.doc_form | |||
| datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form | |||
| ) | |||
| indexing_runner = IndexingRunner() | |||
| @@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| raise NotFound("File not found.") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form | |||
| datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "notion_import": | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| "notion_workspace_id": data_source_info["notion_workspace_id"], | |||
| "notion_obj_id": data_source_info["notion_page_id"], | |||
| @@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "website_crawl": | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| "provider": data_source_info["provider"], | |||
| "job_id": data_source_info["job_id"], | |||
| @@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.docstore.dataset_docstore import DatasetDocumentStore | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| @@ -340,7 +341,9 @@ class IndexingRunner: | |||
| if file_detail: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form | |||
| datasource_type=DatasourceType.FILE.value, | |||
| upload_file=file_detail, | |||
| document_model=dataset_document.doc_form, | |||
| ) | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) | |||
| elif dataset_document.data_source_type == "notion_import": | |||
| @@ -351,7 +354,7 @@ class IndexingRunner: | |||
| ): | |||
| raise ValueError("no notion import info found") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| "notion_workspace_id": data_source_info["notion_workspace_id"], | |||
| "notion_obj_id": data_source_info["notion_page_id"], | |||
| @@ -371,7 +374,7 @@ class IndexingRunner: | |||
| ): | |||
| raise ValueError("no website import info found") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| "provider": data_source_info["provider"], | |||
| "job_id": data_source_info["job_id"], | |||
| @@ -45,7 +45,7 @@ class ExtractProcessor: | |||
| cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False | |||
| ) -> Union[list[Document], str]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", upload_file=upload_file, document_model="text_model" | |||
| datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model" | |||
| ) | |||
| if return_text: | |||
| delimiter = "\n" | |||
| @@ -76,7 +76,7 @@ class ExtractProcessor: | |||
| # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 | |||
| file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" | |||
| Path(file_path).write_bytes(response.content) | |||
| extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") | |||
| extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model") | |||
| if return_text: | |||
| delimiter = "\n" | |||
| return delimiter.join( | |||
| @@ -87,7 +87,7 @@ class ClickZettaVolumeConfig(BaseModel): | |||
| values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME")) | |||
| values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_")) | |||
| values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km")) | |||
| # 暂时禁用权限检查功能,直接设置为false | |||
| # Temporarily disable permission check feature, set directly to false | |||
| values.setdefault("permission_check", False) | |||
| # Validate required fields | |||
| @@ -1,7 +1,7 @@ | |||
| """ClickZetta Volume文件生命周期管理 | |||
| """ClickZetta Volume file lifecycle management | |||
| 该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。 | |||
| 支持知识库文件的完整生命周期管理。 | |||
| This module provides file lifecycle management features including version control, automatic cleanup, backup and restore. | |||
| Supports complete lifecycle management for knowledge base files. | |||
| """ | |||
| import json | |||
| @@ -15,17 +15,17 @@ logger = logging.getLogger(__name__) | |||
| class FileStatus(Enum): | |||
| """文件状态枚举""" | |||
| """File status enumeration""" | |||
| ACTIVE = "active" # 活跃状态 | |||
| ARCHIVED = "archived" # 已归档 | |||
| DELETED = "deleted" # 已删除(软删除) | |||
| BACKUP = "backup" # 备份文件 | |||
| ACTIVE = "active" # Active status | |||
| ARCHIVED = "archived" # Archived | |||
| DELETED = "deleted" # Deleted (soft delete) | |||
| BACKUP = "backup" # Backup file | |||
| @dataclass | |||
| class FileMetadata: | |||
| """文件元数据""" | |||
| """File metadata""" | |||
| filename: str | |||
| size: int | None | |||
| @@ -38,7 +38,7 @@ class FileMetadata: | |||
| parent_version: Optional[int] = None | |||
| def to_dict(self) -> dict: | |||
| """转换为字典格式""" | |||
| """Convert to dictionary format""" | |||
| data = asdict(self) | |||
| data["created_at"] = self.created_at.isoformat() | |||
| data["modified_at"] = self.modified_at.isoformat() | |||
| @@ -47,7 +47,7 @@ class FileMetadata: | |||
| @classmethod | |||
| def from_dict(cls, data: dict) -> "FileMetadata": | |||
| """从字典创建实例""" | |||
| """Create instance from dictionary""" | |||
| data = data.copy() | |||
| data["created_at"] = datetime.fromisoformat(data["created_at"]) | |||
| data["modified_at"] = datetime.fromisoformat(data["modified_at"]) | |||
| @@ -56,14 +56,14 @@ class FileMetadata: | |||
| class FileLifecycleManager: | |||
| """文件生命周期管理器""" | |||
| """File lifecycle manager""" | |||
| def __init__(self, storage, dataset_id: Optional[str] = None): | |||
| """初始化生命周期管理器 | |||
| """Initialize lifecycle manager | |||
| Args: | |||
| storage: ClickZetta Volume存储实例 | |||
| dataset_id: 数据集ID(用于Table Volume) | |||
| storage: ClickZetta Volume storage instance | |||
| dataset_id: Dataset ID (for Table Volume) | |||
| """ | |||
| self._storage = storage | |||
| self._dataset_id = dataset_id | |||
| @@ -72,21 +72,21 @@ class FileLifecycleManager: | |||
| self._backup_prefix = ".backups/" | |||
| self._deleted_prefix = ".deleted/" | |||
| # 获取权限管理器(如果存在) | |||
| # Get permission manager (if exists) | |||
| self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) | |||
| def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: | |||
| """保存文件并管理生命周期 | |||
| """Save file and manage lifecycle | |||
| Args: | |||
| filename: 文件名 | |||
| data: 文件内容 | |||
| tags: 文件标签 | |||
| filename: File name | |||
| data: File content | |||
| tags: File tags | |||
| Returns: | |||
| 文件元数据 | |||
| File metadata | |||
| """ | |||
| # 权限检查 | |||
| # Permission check | |||
| if not self._check_permission(filename, "save"): | |||
| from .volume_permissions import VolumePermissionError | |||
| @@ -98,28 +98,28 @@ class FileLifecycleManager: | |||
| ) | |||
| try: | |||
| # 1. 检查是否存在旧版本 | |||
| # 1. Check if old version exists | |||
| metadata_dict = self._load_metadata() | |||
| current_metadata = metadata_dict.get(filename) | |||
| # 2. 如果存在旧版本,创建版本备份 | |||
| # 2. If old version exists, create version backup | |||
| if current_metadata: | |||
| self._create_version_backup(filename, current_metadata) | |||
| # 3. 计算文件信息 | |||
| # 3. Calculate file information | |||
| now = datetime.now() | |||
| checksum = self._calculate_checksum(data) | |||
| new_version = (current_metadata["version"] + 1) if current_metadata else 1 | |||
| # 4. 保存新文件 | |||
| # 4. Save new file | |||
| self._storage.save(filename, data) | |||
| # 5. 创建元数据 | |||
| # 5. Create metadata | |||
| created_at = now | |||
| parent_version = None | |||
| if current_metadata: | |||
| # 如果created_at是字符串,转换为datetime | |||
| # If created_at is string, convert to datetime | |||
| if isinstance(current_metadata["created_at"], str): | |||
| created_at = datetime.fromisoformat(current_metadata["created_at"]) | |||
| else: | |||
| @@ -138,7 +138,7 @@ class FileLifecycleManager: | |||
| parent_version=parent_version, | |||
| ) | |||
| # 6. 更新元数据 | |||
| # 6. Update metadata | |||
| metadata_dict[filename] = file_metadata.to_dict() | |||
| self._save_metadata(metadata_dict) | |||
| @@ -150,13 +150,13 @@ class FileLifecycleManager: | |||
| raise | |||
| def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: | |||
| """获取文件元数据 | |||
| """Get file metadata | |||
| Args: | |||
| filename: 文件名 | |||
| filename: File name | |||
| Returns: | |||
| 文件元数据,如果不存在返回None | |||
| File metadata, returns None if not exists | |||
| """ | |||
| try: | |||
| metadata_dict = self._load_metadata() | |||
| @@ -168,37 +168,37 @@ class FileLifecycleManager: | |||
| return None | |||
| def list_file_versions(self, filename: str) -> list[FileMetadata]: | |||
| """列出文件的所有版本 | |||
| """List all versions of a file | |||
| Args: | |||
| filename: 文件名 | |||
| filename: File name | |||
| Returns: | |||
| 文件版本列表,按版本号排序 | |||
| File version list, sorted by version number | |||
| """ | |||
| try: | |||
| versions = [] | |||
| # 获取当前版本 | |||
| # Get current version | |||
| current_metadata = self.get_file_metadata(filename) | |||
| if current_metadata: | |||
| versions.append(current_metadata) | |||
| # 获取历史版本 | |||
| # Get historical versions | |||
| try: | |||
| version_files = self._storage.scan(self._dataset_id or "", files=True) | |||
| for file_path in version_files: | |||
| if file_path.startswith(f"{self._version_prefix}{filename}.v"): | |||
| # 解析版本号 | |||
| # Parse version number | |||
| version_str = file_path.split(".v")[-1].split(".")[0] | |||
| try: | |||
| version_num = int(version_str) | |||
| # 这里简化处理,实际应该从版本文件中读取元数据 | |||
| # 暂时创建基本的元数据信息 | |||
| # Simplified processing here, should actually read metadata from version file | |||
| # Temporarily create basic metadata information | |||
| except ValueError: | |||
| continue | |||
| except: | |||
| # 如果无法扫描版本文件,只返回当前版本 | |||
| # If cannot scan version files, only return current version | |||
| pass | |||
| return sorted(versions, key=lambda x: x.version or 0, reverse=True) | |||
| @@ -208,32 +208,32 @@ class FileLifecycleManager: | |||
| return [] | |||
| def restore_version(self, filename: str, version: int) -> bool: | |||
| """恢复文件到指定版本 | |||
| """Restore file to specified version | |||
| Args: | |||
| filename: 文件名 | |||
| version: 要恢复的版本号 | |||
| filename: File name | |||
| version: Version number to restore | |||
| Returns: | |||
| 恢复是否成功 | |||
| Whether restore succeeded | |||
| """ | |||
| try: | |||
| version_filename = f"{self._version_prefix}{filename}.v{version}" | |||
| # 检查版本文件是否存在 | |||
| # Check if version file exists | |||
| if not self._storage.exists(version_filename): | |||
| logger.warning("Version %s of %s not found", version, filename) | |||
| return False | |||
| # 读取版本文件内容 | |||
| # Read version file content | |||
| version_data = self._storage.load_once(version_filename) | |||
| # 保存当前版本为备份 | |||
| # Save current version as backup | |||
| current_metadata = self.get_file_metadata(filename) | |||
| if current_metadata: | |||
| self._create_version_backup(filename, current_metadata.to_dict()) | |||
| # 恢复文件 | |||
| # Restore file | |||
| self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)}) | |||
| return True | |||
| @@ -242,21 +242,21 @@ class FileLifecycleManager: | |||
| return False | |||
| def archive_file(self, filename: str) -> bool: | |||
| """归档文件 | |||
| """Archive file | |||
| Args: | |||
| filename: 文件名 | |||
| filename: File name | |||
| Returns: | |||
| 归档是否成功 | |||
| Whether archive succeeded | |||
| """ | |||
| # 权限检查 | |||
| # Permission check | |||
| if not self._check_permission(filename, "archive"): | |||
| logger.warning("Permission denied for archive operation on file: %s", filename) | |||
| return False | |||
| try: | |||
| # 更新文件状态为归档 | |||
| # Update file status to archived | |||
| metadata_dict = self._load_metadata() | |||
| if filename not in metadata_dict: | |||
| logger.warning("File %s not found in metadata", filename) | |||
| @@ -275,36 +275,36 @@ class FileLifecycleManager: | |||
| return False | |||
| def soft_delete_file(self, filename: str) -> bool: | |||
| """软删除文件(移动到删除目录) | |||
| """Soft delete file (move to deleted directory) | |||
| Args: | |||
| filename: 文件名 | |||
| filename: File name | |||
| Returns: | |||
| 删除是否成功 | |||
| Whether delete succeeded | |||
| """ | |||
| # 权限检查 | |||
| # Permission check | |||
| if not self._check_permission(filename, "delete"): | |||
| logger.warning("Permission denied for soft delete operation on file: %s", filename) | |||
| return False | |||
| try: | |||
| # 检查文件是否存在 | |||
| # Check if file exists | |||
| if not self._storage.exists(filename): | |||
| logger.warning("File %s not found", filename) | |||
| return False | |||
| # 读取文件内容 | |||
| # Read file content | |||
| file_data = self._storage.load_once(filename) | |||
| # 移动到删除目录 | |||
| # Move to deleted directory | |||
| deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" | |||
| self._storage.save(deleted_filename, file_data) | |||
| # 删除原文件 | |||
| # Delete original file | |||
| self._storage.delete(filename) | |||
| # 更新元数据 | |||
| # Update metadata | |||
| metadata_dict = self._load_metadata() | |||
| if filename in metadata_dict: | |||
| metadata_dict[filename]["status"] = FileStatus.DELETED.value | |||
| @@ -319,27 +319,27 @@ class FileLifecycleManager: | |||
| return False | |||
| def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int: | |||
| """清理旧版本文件 | |||
| """Cleanup old version files | |||
| Args: | |||
| max_versions: 保留的最大版本数 | |||
| max_age_days: 版本文件的最大保留天数 | |||
| max_versions: Maximum number of versions to keep | |||
| max_age_days: Maximum retention days for version files | |||
| Returns: | |||
| 清理的文件数量 | |||
| Number of files cleaned | |||
| """ | |||
| try: | |||
| cleaned_count = 0 | |||
| # 获取所有版本文件 | |||
| # Get all version files | |||
| try: | |||
| all_files = self._storage.scan(self._dataset_id or "", files=True) | |||
| version_files = [f for f in all_files if f.startswith(self._version_prefix)] | |||
| # 按文件分组 | |||
| # Group by file | |||
| file_versions: dict[str, list[tuple[int, str]]] = {} | |||
| for version_file in version_files: | |||
| # 解析文件名和版本 | |||
| # Parse filename and version | |||
| parts = version_file[len(self._version_prefix) :].split(".v") | |||
| if len(parts) >= 2: | |||
| base_filename = parts[0] | |||
| @@ -352,12 +352,12 @@ class FileLifecycleManager: | |||
| except ValueError: | |||
| continue | |||
| # 清理每个文件的旧版本 | |||
| # Cleanup old versions for each file | |||
| for base_filename, versions in file_versions.items(): | |||
| # 按版本号排序 | |||
| # Sort by version number | |||
| versions.sort(key=lambda x: x[0], reverse=True) | |||
| # 保留最新的max_versions个版本,删除其余的 | |||
| # Keep the newest max_versions versions, delete the rest | |||
| if len(versions) > max_versions: | |||
| to_delete = versions[max_versions:] | |||
| for version_num, version_file in to_delete: | |||
| @@ -377,10 +377,10 @@ class FileLifecycleManager: | |||
| return 0 | |||
| def get_storage_statistics(self) -> dict[str, Any]: | |||
| """获取存储统计信息 | |||
| """Get storage statistics | |||
| Returns: | |||
| 存储统计字典 | |||
| Storage statistics dictionary | |||
| """ | |||
| try: | |||
| metadata_dict = self._load_metadata() | |||
| @@ -402,7 +402,7 @@ class FileLifecycleManager: | |||
| for filename, metadata in metadata_dict.items(): | |||
| file_meta = FileMetadata.from_dict(metadata) | |||
| # 统计文件状态 | |||
| # Count file status | |||
| if file_meta.status == FileStatus.ACTIVE: | |||
| stats["active_files"] = (stats["active_files"] or 0) + 1 | |||
| elif file_meta.status == FileStatus.ARCHIVED: | |||
| @@ -410,13 +410,13 @@ class FileLifecycleManager: | |||
| elif file_meta.status == FileStatus.DELETED: | |||
| stats["deleted_files"] = (stats["deleted_files"] or 0) + 1 | |||
| # 统计大小 | |||
| # Count size | |||
| stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0) | |||
| # 统计版本 | |||
| # Count versions | |||
| stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0) | |||
| # 找出最新和最旧的文件 | |||
| # Find newest and oldest files | |||
| if oldest_date is None or file_meta.created_at < oldest_date: | |||
| oldest_date = file_meta.created_at | |||
| stats["oldest_file"] = filename | |||
| @@ -432,12 +432,12 @@ class FileLifecycleManager: | |||
| return {} | |||
| def _create_version_backup(self, filename: str, metadata: dict): | |||
| """创建版本备份""" | |||
| """Create version backup""" | |||
| try: | |||
| # 读取当前文件内容 | |||
| # Read current file content | |||
| current_data = self._storage.load_once(filename) | |||
| # 保存为版本文件 | |||
| # Save as version file | |||
| version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" | |||
| self._storage.save(version_filename, current_data) | |||
| @@ -447,7 +447,7 @@ class FileLifecycleManager: | |||
| logger.warning("Failed to create version backup for %s: %s", filename, e) | |||
| def _load_metadata(self) -> dict[str, Any]: | |||
| """加载元数据文件""" | |||
| """Load metadata file""" | |||
| try: | |||
| if self._storage.exists(self._metadata_file): | |||
| metadata_content = self._storage.load_once(self._metadata_file) | |||
| @@ -460,7 +460,7 @@ class FileLifecycleManager: | |||
| return {} | |||
| def _save_metadata(self, metadata_dict: dict): | |||
| """保存元数据文件""" | |||
| """Save metadata file""" | |||
| try: | |||
| metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) | |||
| self._storage.save(self._metadata_file, metadata_content.encode("utf-8")) | |||
| @@ -470,45 +470,45 @@ class FileLifecycleManager: | |||
| raise | |||
| def _calculate_checksum(self, data: bytes) -> str: | |||
| """计算文件校验和""" | |||
| """Calculate file checksum""" | |||
| import hashlib | |||
| return hashlib.md5(data).hexdigest() | |||
| def _check_permission(self, filename: str, operation: str) -> bool: | |||
| """检查文件操作权限 | |||
| """Check file operation permission | |||
| Args: | |||
| filename: 文件名 | |||
| operation: 操作类型 | |||
| filename: File name | |||
| operation: Operation type | |||
| Returns: | |||
| True if permission granted, False otherwise | |||
| """ | |||
| # 如果没有权限管理器,默认允许 | |||
| # If no permission manager, allow by default | |||
| if not self._permission_manager: | |||
| return True | |||
| try: | |||
| # 根据操作类型映射到权限 | |||
| # Map operation type to permission | |||
| operation_mapping = { | |||
| "save": "save", | |||
| "load": "load_once", | |||
| "delete": "delete", | |||
| "archive": "delete", # 归档需要删除权限 | |||
| "restore": "save", # 恢复需要写权限 | |||
| "cleanup": "delete", # 清理需要删除权限 | |||
| "archive": "delete", # Archive requires delete permission | |||
| "restore": "save", # Restore requires write permission | |||
| "cleanup": "delete", # Cleanup requires delete permission | |||
| "read": "load_once", | |||
| "write": "save", | |||
| } | |||
| mapped_operation = operation_mapping.get(operation, operation) | |||
| # 检查权限 | |||
| # Check permission | |||
| result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id) | |||
| return bool(result) | |||
| except Exception as e: | |||
| logger.exception("Permission check failed for %s operation %s", filename, operation) | |||
| # 安全默认:权限检查失败时拒绝访问 | |||
| # Safe default: deny access when permission check fails | |||
| return False | |||
| @@ -1,7 +1,7 @@ | |||
| """ClickZetta Volume权限管理机制 | |||
| """ClickZetta Volume permission management mechanism | |||
| 该模块提供Volume权限检查、验证和管理功能。 | |||
| 根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。 | |||
| This module provides Volume permission checking, validation and management features. | |||
| According to ClickZetta's permission model, different Volume types have different permission requirements. | |||
| """ | |||
| import logging | |||
| @@ -12,29 +12,29 @@ logger = logging.getLogger(__name__) | |||
| class VolumePermission(Enum): | |||
| """Volume权限类型枚举""" | |||
| """Volume permission type enumeration""" | |||
| READ = "SELECT" # 对应ClickZetta的SELECT权限 | |||
| WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 | |||
| LIST = "SELECT" # 列出文件需要SELECT权限 | |||
| DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限 | |||
| USAGE = "USAGE" # External Volume需要的基本权限 | |||
| READ = "SELECT" # Corresponds to ClickZetta's SELECT permission | |||
| WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions | |||
| LIST = "SELECT" # Listing files requires SELECT permission | |||
| DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions | |||
| USAGE = "USAGE" # Basic permission required for External Volume | |||
| class VolumePermissionManager: | |||
| """Volume权限管理器""" | |||
| """Volume permission manager""" | |||
| def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): | |||
| """初始化权限管理器 | |||
| """Initialize permission manager | |||
| Args: | |||
| connection_or_config: ClickZetta连接对象或配置字典 | |||
| volume_type: Volume类型 (user|table|external) | |||
| volume_name: Volume名称 (用于external volume) | |||
| connection_or_config: ClickZetta connection object or configuration dictionary | |||
| volume_type: Volume type (user|table|external) | |||
| volume_name: Volume name (for external volume) | |||
| """ | |||
| # 支持两种初始化方式:连接对象或配置字典 | |||
| # Support two initialization methods: connection object or configuration dictionary | |||
| if isinstance(connection_or_config, dict): | |||
| # 从配置字典创建连接 | |||
| # Create connection from configuration dictionary | |||
| import clickzetta # type: ignore[import-untyped] | |||
| config = connection_or_config | |||
| @@ -50,7 +50,7 @@ class VolumePermissionManager: | |||
| self._volume_type = config.get("volume_type", volume_type) | |||
| self._volume_name = config.get("volume_name", volume_name) | |||
| else: | |||
| # 直接使用连接对象 | |||
| # Use connection object directly | |||
| self._connection = connection_or_config | |||
| self._volume_type = volume_type | |||
| self._volume_name = volume_name | |||
| @@ -61,14 +61,14 @@ class VolumePermissionManager: | |||
| raise ValueError("volume_type is required") | |||
| self._permission_cache: dict[str, set[str]] = {} | |||
| self._current_username = None # 将从连接中获取当前用户名 | |||
| self._current_username = None # Will get current username from connection | |||
| def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: | |||
| """检查用户是否有执行特定操作的权限 | |||
| """Check if user has permission to perform specific operation | |||
| Args: | |||
| operation: 要执行的操作类型 | |||
| dataset_id: 数据集ID (用于table volume) | |||
| operation: Type of operation to perform | |||
| dataset_id: Dataset ID (for table volume) | |||
| Returns: | |||
| True if user has permission, False otherwise | |||
| @@ -89,20 +89,20 @@ class VolumePermissionManager: | |||
| return False | |||
| def _check_user_volume_permission(self, operation: VolumePermission) -> bool: | |||
| """检查User Volume权限 | |||
| """Check User Volume permission | |||
| User Volume权限规则: | |||
| - 用户对自己的User Volume有全部权限 | |||
| - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 | |||
| - 更注重连接身份验证,而不是复杂的权限检查 | |||
| User Volume permission rules: | |||
| - User has full permissions on their own User Volume | |||
| - As long as user can connect to ClickZetta, they have basic User Volume permissions by default | |||
| - Focus more on connection authentication rather than complex permission checking | |||
| """ | |||
| try: | |||
| # 获取当前用户名 | |||
| # Get current username | |||
| current_user = self._get_current_username() | |||
| # 检查基本连接状态 | |||
| # Check basic connection status | |||
| with self._connection.cursor() as cursor: | |||
| # 简单的连接测试,如果能执行查询说明用户有基本权限 | |||
| # Simple connection test, if query can be executed user has basic permissions | |||
| cursor.execute("SELECT 1") | |||
| result = cursor.fetchone() | |||
| @@ -121,17 +121,18 @@ class VolumePermissionManager: | |||
| except Exception as e: | |||
| logger.exception("User Volume permission check failed") | |||
| # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 | |||
| # For User Volume, if permission check fails, it might be a configuration issue, | |||
| # provide friendlier error message | |||
| logger.info("User Volume permission check failed, but permission checking is disabled in this version") | |||
| return False | |||
| def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: | |||
| """检查Table Volume权限 | |||
| """Check Table Volume permission | |||
| Table Volume权限规则: | |||
| - Table Volume权限继承对应表的权限 | |||
| - SELECT权限 -> 可以READ/LIST文件 | |||
| - INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件 | |||
| Table Volume permission rules: | |||
| - Table Volume permissions inherit from corresponding table permissions | |||
| - SELECT permission -> can READ/LIST files | |||
| - INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files | |||
| """ | |||
| if not dataset_id: | |||
| logger.warning("dataset_id is required for table volume permission check") | |||
| @@ -140,11 +141,11 @@ class VolumePermissionManager: | |||
| table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id | |||
| try: | |||
| # 检查表权限 | |||
| # Check table permissions | |||
| permissions = self._get_table_permissions(table_name) | |||
| required_permissions = set(operation.value.split(",")) | |||
| # 检查是否有所需的所有权限 | |||
| # Check if has all required permissions | |||
| has_permission = required_permissions.issubset(permissions) | |||
| logger.debug( | |||
| @@ -163,22 +164,22 @@ class VolumePermissionManager: | |||
| return False | |||
| def _check_external_volume_permission(self, operation: VolumePermission) -> bool: | |||
| """检查External Volume权限 | |||
| """Check External Volume permission | |||
| External Volume权限规则: | |||
| - 尝试获取对External Volume的权限 | |||
| - 如果权限检查失败,进行备选验证 | |||
| - 对于开发环境,提供更宽松的权限检查 | |||
| External Volume permission rules: | |||
| - Try to get permissions for External Volume | |||
| - If permission check fails, perform fallback verification | |||
| - For development environment, provide more lenient permission checking | |||
| """ | |||
| if not self._volume_name: | |||
| logger.warning("volume_name is required for external volume permission check") | |||
| return False | |||
| try: | |||
| # 检查External Volume权限 | |||
| # Check External Volume permissions | |||
| permissions = self._get_external_volume_permissions(self._volume_name) | |||
| # External Volume权限映射:根据操作类型确定所需权限 | |||
| # External Volume permission mapping: determine required permissions based on operation type | |||
| required_permissions = set() | |||
| if operation in [VolumePermission.READ, VolumePermission.LIST]: | |||
| @@ -186,7 +187,7 @@ class VolumePermissionManager: | |||
| elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: | |||
| required_permissions.add("write") | |||
| # 检查是否有所需的所有权限 | |||
| # Check if has all required permissions | |||
| has_permission = required_permissions.issubset(permissions) | |||
| logger.debug( | |||
| @@ -198,11 +199,11 @@ class VolumePermissionManager: | |||
| has_permission, | |||
| ) | |||
| # 如果权限检查失败,尝试备选验证 | |||
| # If permission check fails, try fallback verification | |||
| if not has_permission: | |||
| logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name) | |||
| # 备选验证:尝试列出Volume来验证基本访问权限 | |||
| # Fallback verification: try listing Volume to verify basic access permissions | |||
| try: | |||
| with self._connection.cursor() as cursor: | |||
| cursor.execute("SHOW VOLUMES") | |||
| @@ -222,13 +223,13 @@ class VolumePermissionManager: | |||
| return False | |||
| def _get_table_permissions(self, table_name: str) -> set[str]: | |||
| """获取用户对指定表的权限 | |||
| """Get user permissions for specified table | |||
| Args: | |||
| table_name: 表名 | |||
| table_name: Table name | |||
| Returns: | |||
| 用户对该表的权限集合 | |||
| Set of user permissions for this table | |||
| """ | |||
| cache_key = f"table:{table_name}" | |||
| @@ -239,18 +240,18 @@ class VolumePermissionManager: | |||
| try: | |||
| with self._connection.cursor() as cursor: | |||
| # 使用正确的ClickZetta语法检查当前用户权限 | |||
| # Use correct ClickZetta syntax to check current user permissions | |||
| cursor.execute("SHOW GRANTS") | |||
| grants = cursor.fetchall() | |||
| # 解析权限结果,查找对该表的权限 | |||
| # Parse permission results, find permissions for this table | |||
| for grant in grants: | |||
| if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) | |||
| if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...) | |||
| privilege = grant[0].upper() | |||
| object_type = grant[1].upper() if len(grant) > 1 else "" | |||
| object_name = grant[2] if len(grant) > 2 else "" | |||
| # 检查是否是对该表的权限 | |||
| # Check if it's permission for this table | |||
| if ( | |||
| object_type == "TABLE" | |||
| and object_name == table_name | |||
| @@ -263,7 +264,7 @@ class VolumePermissionManager: | |||
| else: | |||
| permissions.add(privilege) | |||
| # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 | |||
| # If no explicit permissions found, try executing a simple query to verify permissions | |||
| if not permissions: | |||
| try: | |||
| cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1") | |||
| @@ -273,15 +274,15 @@ class VolumePermissionManager: | |||
| except Exception as e: | |||
| logger.warning("Could not check table permissions for %s: %s", table_name, e) | |||
| # 安全默认:权限检查失败时拒绝访问 | |||
| # Safe default: deny access when permission check fails | |||
| pass | |||
| # 缓存权限信息 | |||
| # Cache permission information | |||
| self._permission_cache[cache_key] = permissions | |||
| return permissions | |||
| def _get_current_username(self) -> str: | |||
| """获取当前用户名""" | |||
| """Get current username""" | |||
| if self._current_username: | |||
| return self._current_username | |||
| @@ -298,7 +299,7 @@ class VolumePermissionManager: | |||
| return "unknown" | |||
| def _get_user_permissions(self, username: str) -> set[str]: | |||
| """获取用户的基本权限集合""" | |||
| """Get user's basic permission set""" | |||
| cache_key = f"user_permissions:{username}" | |||
| if cache_key in self._permission_cache: | |||
| @@ -308,17 +309,17 @@ class VolumePermissionManager: | |||
| try: | |||
| with self._connection.cursor() as cursor: | |||
| # 使用正确的ClickZetta语法检查当前用户权限 | |||
| # Use correct ClickZetta syntax to check current user permissions | |||
| cursor.execute("SHOW GRANTS") | |||
| grants = cursor.fetchall() | |||
| # 解析权限结果,查找用户的基本权限 | |||
| # Parse permission results, find user's basic permissions | |||
| for grant in grants: | |||
| if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) | |||
| if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...) | |||
| privilege = grant[0].upper() | |||
| object_type = grant[1].upper() if len(grant) > 1 else "" | |||
| # 收集所有相关权限 | |||
| # Collect all relevant permissions | |||
| if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: | |||
| if privilege == "ALL": | |||
| permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) | |||
| @@ -327,21 +328,21 @@ class VolumePermissionManager: | |||
| except Exception as e: | |||
| logger.warning("Could not check user permissions for %s: %s", username, e) | |||
| # 安全默认:权限检查失败时拒绝访问 | |||
| # Safe default: deny access when permission check fails | |||
| pass | |||
| # 缓存权限信息 | |||
| # Cache permission information | |||
| self._permission_cache[cache_key] = permissions | |||
| return permissions | |||
| def _get_external_volume_permissions(self, volume_name: str) -> set[str]: | |||
| """获取用户对指定External Volume的权限 | |||
| """Get user permissions for specified External Volume | |||
| Args: | |||
| volume_name: External Volume名称 | |||
| volume_name: External Volume name | |||
| Returns: | |||
| 用户对该Volume的权限集合 | |||
| Set of user permissions for this Volume | |||
| """ | |||
| cache_key = f"external_volume:{volume_name}" | |||
| @@ -352,15 +353,15 @@ class VolumePermissionManager: | |||
| try: | |||
| with self._connection.cursor() as cursor: | |||
| # 使用正确的ClickZetta语法检查Volume权限 | |||
| # Use correct ClickZetta syntax to check Volume permissions | |||
| logger.info("Checking permissions for volume: %s", volume_name) | |||
| cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}") | |||
| grants = cursor.fetchall() | |||
| logger.info("Raw grants result for %s: %s", volume_name, grants) | |||
| # 解析权限结果 | |||
| # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, | |||
| # Parse permission results | |||
| # Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to, | |||
| # grantee_name, grantor_name, grant_option, granted_time) | |||
| for grant in grants: | |||
| logger.info("Processing grant: %s", grant) | |||
| @@ -378,7 +379,7 @@ class VolumePermissionManager: | |||
| object_name, | |||
| ) | |||
| # 检查是否是对该Volume的权限或者是层级权限 | |||
| # Check if it's permission for this Volume or hierarchical permission | |||
| if ( | |||
| granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name) | |||
| ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"): | |||
| @@ -399,14 +400,14 @@ class VolumePermissionManager: | |||
| logger.info("Final permissions for %s: %s", volume_name, permissions) | |||
| # 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限 | |||
| # If no explicit permissions found, try viewing Volume list to verify basic permissions | |||
| if not permissions: | |||
| try: | |||
| cursor.execute("SHOW VOLUMES") | |||
| volumes = cursor.fetchall() | |||
| for volume in volumes: | |||
| if len(volume) > 0 and volume[0] == volume_name: | |||
| permissions.add("read") # 至少有读权限 | |||
| permissions.add("read") # At least has read permission | |||
| logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name) | |||
| break | |||
| except Exception: | |||
| @@ -414,7 +415,7 @@ class VolumePermissionManager: | |||
| except Exception as e: | |||
| logger.warning("Could not check external volume permissions for %s: %s", volume_name, e) | |||
| # 在权限检查失败时,尝试基本的Volume访问验证 | |||
| # When permission check fails, try basic Volume access verification | |||
| try: | |||
| with self._connection.cursor() as cursor: | |||
| cursor.execute("SHOW VOLUMES") | |||
| @@ -423,30 +424,30 @@ class VolumePermissionManager: | |||
| if len(volume) > 0 and volume[0] == volume_name: | |||
| logger.info("Basic volume access verified for %s", volume_name) | |||
| permissions.add("read") | |||
| permissions.add("write") # 假设有写权限 | |||
| permissions.add("write") # Assume has write permission | |||
| break | |||
| except Exception as basic_e: | |||
| logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e) | |||
| # 最后的备选方案:假设有基本权限 | |||
| # Last fallback: assume basic permissions | |||
| permissions.add("read") | |||
| # 缓存权限信息 | |||
| # Cache permission information | |||
| self._permission_cache[cache_key] = permissions | |||
| return permissions | |||
| def clear_permission_cache(self): | |||
| """清空权限缓存""" | |||
| """Clear permission cache""" | |||
| self._permission_cache.clear() | |||
| logger.debug("Permission cache cleared") | |||
| def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]: | |||
| """获取权限摘要 | |||
| """Get permission summary | |||
| Args: | |||
| dataset_id: 数据集ID (用于table volume) | |||
| dataset_id: Dataset ID (for table volume) | |||
| Returns: | |||
| 权限摘要字典 | |||
| Permission summary dictionary | |||
| """ | |||
| summary = {} | |||
| @@ -456,43 +457,43 @@ class VolumePermissionManager: | |||
| return summary | |||
| def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool: | |||
| """检查文件路径的权限继承 | |||
| """Check permission inheritance for file path | |||
| Args: | |||
| file_path: 文件路径 | |||
| operation: 要执行的操作 | |||
| file_path: File path | |||
| operation: Operation to perform | |||
| Returns: | |||
| True if user has permission, False otherwise | |||
| """ | |||
| try: | |||
| # 解析文件路径 | |||
| # Parse file path | |||
| path_parts = file_path.strip("/").split("/") | |||
| if not path_parts: | |||
| logger.warning("Invalid file path for permission inheritance check") | |||
| return False | |||
| # 对于Table Volume,第一层是dataset_id | |||
| # For Table Volume, first layer is dataset_id | |||
| if self._volume_type == "table": | |||
| if len(path_parts) < 1: | |||
| return False | |||
| dataset_id = path_parts[0] | |||
| # 检查对dataset的权限 | |||
| # Check permissions for dataset | |||
| has_dataset_permission = self.check_permission(operation, dataset_id) | |||
| if not has_dataset_permission: | |||
| logger.debug("Permission denied for dataset %s", dataset_id) | |||
| return False | |||
| # 检查路径遍历攻击 | |||
| # Check path traversal attack | |||
| if self._contains_path_traversal(file_path): | |||
| logger.warning("Path traversal attack detected: %s", file_path) | |||
| return False | |||
| # 检查是否访问敏感目录 | |||
| # Check if accessing sensitive directory | |||
| if self._is_sensitive_path(file_path): | |||
| logger.warning("Access to sensitive path denied: %s", file_path) | |||
| return False | |||
| @@ -501,20 +502,20 @@ class VolumePermissionManager: | |||
| return True | |||
| elif self._volume_type == "user": | |||
| # User Volume的权限继承 | |||
| # User Volume permission inheritance | |||
| current_user = self._get_current_username() | |||
| # 检查是否试图访问其他用户的目录 | |||
| # Check if attempting to access other user's directory | |||
| if len(path_parts) > 1 and path_parts[0] != current_user: | |||
| logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0]) | |||
| return False | |||
| # 检查基本权限 | |||
| # Check basic permissions | |||
| return self.check_permission(operation) | |||
| elif self._volume_type == "external": | |||
| # External Volume的权限继承 | |||
| # 检查对External Volume的权限 | |||
| # External Volume permission inheritance | |||
| # Check permissions for External Volume | |||
| return self.check_permission(operation) | |||
| else: | |||
| @@ -526,8 +527,8 @@ class VolumePermissionManager: | |||
| return False | |||
| def _contains_path_traversal(self, file_path: str) -> bool: | |||
| """检查路径是否包含路径遍历攻击""" | |||
| # 检查常见的路径遍历模式 | |||
| """Check if path contains path traversal attack""" | |||
| # Check common path traversal patterns | |||
| traversal_patterns = [ | |||
| "../", | |||
| "..\\", | |||
| @@ -547,18 +548,18 @@ class VolumePermissionManager: | |||
| if pattern in file_path_lower: | |||
| return True | |||
| # 检查绝对路径 | |||
| # Check absolute path | |||
| if file_path.startswith("/") or file_path.startswith("\\"): | |||
| return True | |||
| # 检查Windows驱动器路径 | |||
| # Check Windows drive path | |||
| if len(file_path) >= 2 and file_path[1] == ":": | |||
| return True | |||
| return False | |||
| def _is_sensitive_path(self, file_path: str) -> bool: | |||
| """检查路径是否为敏感路径""" | |||
| """Check if path is sensitive path""" | |||
| sensitive_patterns = [ | |||
| "passwd", | |||
| "shadow", | |||
| @@ -582,11 +583,11 @@ class VolumePermissionManager: | |||
| return any(pattern in file_path_lower for pattern in sensitive_patterns) | |||
| def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: | |||
| """验证操作权限 | |||
| """Validate operation permission | |||
| Args: | |||
| operation: 操作名称 (save|load|exists|delete|scan) | |||
| dataset_id: 数据集ID | |||
| operation: Operation name (save|load|exists|delete|scan) | |||
| dataset_id: Dataset ID | |||
| Returns: | |||
| True if operation is allowed, False otherwise | |||
| @@ -611,7 +612,7 @@ class VolumePermissionManager: | |||
| class VolumePermissionError(Exception): | |||
| """Volume权限错误异常""" | |||
| """Volume permission error exception""" | |||
| def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): | |||
| self.operation = operation | |||
| @@ -623,15 +624,15 @@ class VolumePermissionError(Exception): | |||
| def check_volume_permission( | |||
| permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None | |||
| ) -> None: | |||
| """权限检查装饰器函数 | |||
| """Permission check decorator function | |||
| Args: | |||
| permission_manager: 权限管理器 | |||
| operation: 操作名称 | |||
| dataset_id: 数据集ID | |||
| permission_manager: Permission manager | |||
| operation: Operation name | |||
| dataset_id: Dataset ID | |||
| Raises: | |||
| VolumePermissionError: 如果没有权限 | |||
| VolumePermissionError: If no permission | |||
| """ | |||
| if not permission_manager.validate_operation(operation, dataset_id): | |||
| error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" | |||
| @@ -1,11 +0,0 @@ | |||
| from tests.integration_tests.utils.parent_class import ParentClass | |||
| class LazyLoadChildClass(ParentClass): | |||
| """Test lazy load child class for module import helper tests""" | |||
| def __init__(self, name): | |||
| super().__init__(name) | |||
| def get_name(self): | |||
| return self.name | |||
| @@ -0,0 +1,716 @@ | |||
| import json | |||
| from unittest.mock import patch | |||
| import pytest | |||
| from faker import Faker | |||
| from models.tools import WorkflowToolProvider | |||
| from models.workflow import Workflow as WorkflowModel | |||
| from services.account_service import AccountService, TenantService | |||
| from services.app_service import AppService | |||
| from services.tools.workflow_tools_manage_service import WorkflowToolManageService | |||
| class TestWorkflowToolManageService: | |||
| """Integration tests for WorkflowToolManageService using testcontainers.""" | |||
| @pytest.fixture | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.app_service.FeatureService") as mock_feature_service, | |||
| patch("services.app_service.EnterpriseService") as mock_enterprise_service, | |||
| patch("services.app_service.ModelManager") as mock_model_manager, | |||
| patch("services.account_service.FeatureService") as mock_account_feature_service, | |||
| patch( | |||
| "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" | |||
| ) as mock_workflow_tool_provider_controller, | |||
| patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as mock_tool_label_manager, | |||
| patch("services.tools.workflow_tools_manage_service.ToolTransformService") as mock_tool_transform_service, | |||
| ): | |||
| # Setup default mock returns for app service | |||
| mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False | |||
| mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None | |||
| mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None | |||
| # Setup default mock returns for account service | |||
| mock_account_feature_service.get_system_features.return_value.is_allow_register = True | |||
| # Mock ModelManager for model configuration | |||
| mock_model_instance = mock_model_manager.return_value | |||
| mock_model_instance.get_default_model_instance.return_value = None | |||
| mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") | |||
| # Mock WorkflowToolProviderController | |||
| mock_workflow_tool_provider_controller.from_db.return_value = None | |||
| # Mock ToolLabelManager | |||
| mock_tool_label_manager.update_tool_labels.return_value = None | |||
| # Mock ToolTransformService | |||
| mock_tool_transform_service.workflow_provider_to_controller.return_value = None | |||
| yield { | |||
| "feature_service": mock_feature_service, | |||
| "enterprise_service": mock_enterprise_service, | |||
| "model_manager": mock_model_manager, | |||
| "account_feature_service": mock_account_feature_service, | |||
| "workflow_tool_provider_controller": mock_workflow_tool_provider_controller, | |||
| "tool_label_manager": mock_tool_label_manager, | |||
| "tool_transform_service": mock_tool_transform_service, | |||
| } | |||
| def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Helper method to create a test app and account for testing. | |||
| Args: | |||
| db_session_with_containers: Database session from testcontainers infrastructure | |||
| mock_external_service_dependencies: Mock dependencies | |||
| Returns: | |||
| tuple: (app, account, workflow) - Created app, account and workflow instances | |||
| """ | |||
| fake = Faker() | |||
| # Setup mocks for account creation | |||
| mock_external_service_dependencies[ | |||
| "account_feature_service" | |||
| ].get_system_features.return_value.is_allow_register = True | |||
| # Create account and tenant | |||
| account = AccountService.create_account( | |||
| email=fake.email(), | |||
| name=fake.name(), | |||
| interface_language="en-US", | |||
| password=fake.password(length=12), | |||
| ) | |||
| TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) | |||
| tenant = account.current_tenant | |||
| # Create app with realistic data | |||
| app_args = { | |||
| "name": fake.company(), | |||
| "description": fake.text(max_nb_chars=100), | |||
| "mode": "workflow", | |||
| "icon_type": "emoji", | |||
| "icon": "🤖", | |||
| "icon_background": "#FF6B6B", | |||
| "api_rph": 100, | |||
| "api_rpm": 10, | |||
| } | |||
| app_service = AppService() | |||
| app = app_service.create_app(tenant.id, app_args, account) | |||
| # Create workflow for the app | |||
| workflow = WorkflowModel( | |||
| tenant_id=tenant.id, | |||
| app_id=app.id, | |||
| type="workflow", | |||
| version="1.0.0", | |||
| graph=json.dumps({}), | |||
| features=json.dumps({}), | |||
| created_by=account.id, | |||
| environment_variables=[], | |||
| conversation_variables=[], | |||
| ) | |||
| from extensions.ext_database import db | |||
| db.session.add(workflow) | |||
| db.session.commit() | |||
| # Update app to reference the workflow | |||
| app.workflow_id = workflow.id | |||
| db.session.commit() | |||
| return app, account, workflow | |||
| def _create_test_workflow_tool_parameters(self): | |||
| """Helper method to create valid workflow tool parameters.""" | |||
| return [ | |||
| { | |||
| "name": "input_text", | |||
| "description": "Input text for processing", | |||
| "form": "form", | |||
| "type": "string", | |||
| "required": True, | |||
| }, | |||
| { | |||
| "name": "output_format", | |||
| "description": "Output format specification", | |||
| "form": "form", | |||
| "type": "select", | |||
| "required": False, | |||
| }, | |||
| ] | |||
| def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test successful workflow tool creation with valid parameters. | |||
| This test verifies: | |||
| - Proper workflow tool creation with all required fields | |||
| - Correct database state after creation | |||
| - Proper relationship establishment | |||
| - External service integration | |||
| - Return value correctness | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Setup workflow tool creation parameters | |||
| tool_name = fake.word() | |||
| tool_label = fake.word() | |||
| tool_icon = {"type": "emoji", "emoji": "🔧"} | |||
| tool_description = fake.text(max_nb_chars=200) | |||
| tool_parameters = self._create_test_workflow_tool_parameters() | |||
| tool_privacy_policy = fake.text(max_nb_chars=100) | |||
| tool_labels = ["automation", "workflow"] | |||
| # Execute the method under test | |||
| result = WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=tool_name, | |||
| label=tool_label, | |||
| icon=tool_icon, | |||
| description=tool_description, | |||
| parameters=tool_parameters, | |||
| privacy_policy=tool_privacy_policy, | |||
| labels=tool_labels, | |||
| ) | |||
| # Verify the result | |||
| assert result == {"result": "success"} | |||
| # Verify database state | |||
| from extensions.ext_database import db | |||
| # Check if workflow tool provider was created | |||
| created_tool_provider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| WorkflowToolProvider.app_id == app.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| assert created_tool_provider is not None | |||
| assert created_tool_provider.name == tool_name | |||
| assert created_tool_provider.label == tool_label | |||
| assert created_tool_provider.icon == json.dumps(tool_icon) | |||
| assert created_tool_provider.description == tool_description | |||
| assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) | |||
| assert created_tool_provider.privacy_policy == tool_privacy_policy | |||
| assert created_tool_provider.version == workflow.version | |||
| assert created_tool_provider.user_id == account.id | |||
| assert created_tool_provider.tenant_id == account.current_tenant.id | |||
| assert created_tool_provider.app_id == app.id | |||
| # Verify external service calls | |||
| mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called_once() | |||
| mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() | |||
| mock_external_service_dependencies[ | |||
| "tool_transform_service" | |||
| ].workflow_provider_to_controller.assert_called_once() | |||
| def test_create_workflow_tool_duplicate_name_error( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool creation fails when name already exists. | |||
| This test verifies: | |||
| - Proper error handling for duplicate tool names | |||
| - Database constraint enforcement | |||
| - Correct error message | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Create first workflow tool | |||
| first_tool_name = fake.word() | |||
| first_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=first_tool_name, | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=first_tool_parameters, | |||
| ) | |||
| # Attempt to create second workflow tool with same name | |||
| second_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=first_tool_name, # Same name | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "⚙️"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=second_tool_parameters, | |||
| ) | |||
| # Verify error message | |||
| assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) | |||
| # Verify only one tool was created | |||
| from extensions.ext_database import db | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 1 | |||
| def test_create_workflow_tool_invalid_app_error( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool creation fails when app does not exist. | |||
| This test verifies: | |||
| - Proper error handling for non-existent apps | |||
| - Correct error message | |||
| - No database changes when app is invalid | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Generate non-existent app ID | |||
| non_existent_app_id = fake.uuid4() | |||
| # Attempt to create workflow tool with non-existent app | |||
| tool_parameters = self._create_test_workflow_tool_parameters() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=non_existent_app_id, # Non-existent app ID | |||
| name=fake.word(), | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=tool_parameters, | |||
| ) | |||
| # Verify error message | |||
| assert f"App {non_existent_app_id} not found" in str(exc_info.value) | |||
| # Verify no workflow tool was created | |||
| from extensions.ext_database import db | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 0 | |||
| def test_create_workflow_tool_invalid_parameters_error( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool creation fails when parameters are invalid. | |||
| This test verifies: | |||
| - Proper error handling for invalid parameter configurations | |||
| - Parameter validation enforcement | |||
| - Correct error message | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Setup invalid workflow tool parameters (missing required fields) | |||
| invalid_parameters = [ | |||
| { | |||
| "name": "input_text", | |||
| # Missing description and form fields | |||
| "type": "string", | |||
| "required": True, | |||
| } | |||
| ] | |||
| # Attempt to create workflow tool with invalid parameters | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=fake.word(), | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=invalid_parameters, | |||
| ) | |||
| # Verify error message contains validation error | |||
| assert "validation error" in str(exc_info.value).lower() | |||
| # Verify no workflow tool was created | |||
| from extensions.ext_database import db | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 0 | |||
| def test_create_workflow_tool_duplicate_app_id_error( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool creation fails when app_id already exists. | |||
| This test verifies: | |||
| - Proper error handling for duplicate app_id | |||
| - Database constraint enforcement for app_id uniqueness | |||
| - Correct error message | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Create first workflow tool | |||
| first_tool_name = fake.word() | |||
| first_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=first_tool_name, | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=first_tool_parameters, | |||
| ) | |||
| # Attempt to create second workflow tool with same app_id but different name | |||
| second_tool_name = fake.word() | |||
| second_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, # Same app_id | |||
| name=second_tool_name, # Different name | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "⚙️"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=second_tool_parameters, | |||
| ) | |||
| # Verify error message | |||
| assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) | |||
| # Verify only one tool was created | |||
| from extensions.ext_database import db | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 1 | |||
| def test_create_workflow_tool_workflow_not_found_error( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool creation fails when app has no workflow. | |||
| This test verifies: | |||
| - Proper error handling for apps without workflows | |||
| - Correct error message | |||
| - No database changes when workflow is missing | |||
| """ | |||
| fake = Faker() | |||
| # Create test data but without workflow | |||
| app, account, _ = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Remove workflow reference from app | |||
| from extensions.ext_database import db | |||
| app.workflow_id = None | |||
| db.session.commit() | |||
| # Attempt to create workflow tool for app without workflow | |||
| tool_parameters = self._create_test_workflow_tool_parameters() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=fake.word(), | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=tool_parameters, | |||
| ) | |||
| # Verify error message | |||
| assert f"Workflow not found for app {app.id}" in str(exc_info.value) | |||
| # Verify no workflow tool was created | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 0 | |||
| def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test successful workflow tool update with valid parameters. | |||
| This test verifies: | |||
| - Proper workflow tool update with all required fields | |||
| - Correct database state after update | |||
| - Proper relationship maintenance | |||
| - External service integration | |||
| - Return value correctness | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Create initial workflow tool | |||
| initial_tool_name = fake.word() | |||
| initial_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=initial_tool_name, | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=initial_tool_parameters, | |||
| ) | |||
| # Get the created tool | |||
| from extensions.ext_database import db | |||
| created_tool = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| WorkflowToolProvider.app_id == app.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| # Setup update parameters | |||
| updated_tool_name = fake.word() | |||
| updated_tool_label = fake.word() | |||
| updated_tool_icon = {"type": "emoji", "emoji": "⚙️"} | |||
| updated_tool_description = fake.text(max_nb_chars=200) | |||
| updated_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| updated_tool_privacy_policy = fake.text(max_nb_chars=100) | |||
| updated_tool_labels = ["automation", "updated"] | |||
| # Execute the update method | |||
| result = WorkflowToolManageService.update_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_tool_id=created_tool.id, | |||
| name=updated_tool_name, | |||
| label=updated_tool_label, | |||
| icon=updated_tool_icon, | |||
| description=updated_tool_description, | |||
| parameters=updated_tool_parameters, | |||
| privacy_policy=updated_tool_privacy_policy, | |||
| labels=updated_tool_labels, | |||
| ) | |||
| # Verify the result | |||
| assert result == {"result": "success"} | |||
| # Verify database state was updated | |||
| db.session.refresh(created_tool) | |||
| assert created_tool.name == updated_tool_name | |||
| assert created_tool.label == updated_tool_label | |||
| assert created_tool.icon == json.dumps(updated_tool_icon) | |||
| assert created_tool.description == updated_tool_description | |||
| assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) | |||
| assert created_tool.privacy_policy == updated_tool_privacy_policy | |||
| assert created_tool.version == workflow.version | |||
| assert created_tool.updated_at is not None | |||
| # Verify external service calls | |||
| mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called() | |||
| mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() | |||
| mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() | |||
| def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test workflow tool update fails when tool does not exist. | |||
| This test verifies: | |||
| - Proper error handling for non-existent tools | |||
| - Correct error message | |||
| - No database changes when tool is invalid | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Generate non-existent tool ID | |||
| non_existent_tool_id = fake.uuid4() | |||
| # Attempt to update non-existent workflow tool | |||
| tool_parameters = self._create_test_workflow_tool_parameters() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| WorkflowToolManageService.update_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_tool_id=non_existent_tool_id, # Non-existent tool ID | |||
| name=fake.word(), | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=tool_parameters, | |||
| ) | |||
| # Verify error message | |||
| assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) | |||
| # Verify no workflow tool was created | |||
| from extensions.ext_database import db | |||
| tool_count = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| assert tool_count == 0 | |||
| def test_update_workflow_tool_same_name_success( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test workflow tool update succeeds when keeping the same name. | |||
| This test verifies: | |||
| - Proper handling when updating tool with same name | |||
| - Database state maintenance | |||
| - Update timestamp is set | |||
| """ | |||
| fake = Faker() | |||
| # Create test data | |||
| app, account, workflow = self._create_test_app_and_account( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Create first workflow tool | |||
| first_tool_name = fake.word() | |||
| first_tool_parameters = self._create_test_workflow_tool_parameters() | |||
| WorkflowToolManageService.create_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_app_id=app.id, | |||
| name=first_tool_name, | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "🔧"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=first_tool_parameters, | |||
| ) | |||
| # Get the created tool | |||
| from extensions.ext_database import db | |||
| created_tool = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where( | |||
| WorkflowToolProvider.tenant_id == account.current_tenant.id, | |||
| WorkflowToolProvider.app_id == app.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| # Attempt to update tool with same name (should not fail) | |||
| result = WorkflowToolManageService.update_workflow_tool( | |||
| user_id=account.id, | |||
| tenant_id=account.current_tenant.id, | |||
| workflow_tool_id=created_tool.id, | |||
| name=first_tool_name, # Same name | |||
| label=fake.word(), | |||
| icon={"type": "emoji", "emoji": "⚙️"}, | |||
| description=fake.text(max_nb_chars=200), | |||
| parameters=first_tool_parameters, | |||
| ) | |||
| # Verify update was successful | |||
| assert result == {"result": "success"} | |||
| # Verify tool still exists with the same name | |||
| db.session.refresh(created_tool) | |||
| assert created_tool.name == first_tool_name | |||
| assert created_tool.updated_at is not None | |||
| @@ -14,11 +14,5 @@ uv run --directory api --dev ruff format ./ | |||
| # run dotenv-linter linter | |||
| uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example | |||
| # run import-linter | |||
| uv run --directory api --dev lint-imports | |||
| # run ty check | |||
| dev/ty-check | |||
| # run mypy check | |||
| dev/mypy-check | |||
| @@ -41,15 +41,6 @@ if $api_modified; then | |||
| echo "Please run 'dev/reformat' to fix the fixable linting errors." | |||
| exit 1 | |||
| fi | |||
| # run ty checks | |||
| uv run --directory api --dev ty check || status=$? | |||
| status=${status:-0} | |||
| if [ $status -ne 0 ]; then | |||
| echo "ty type checker on api module error, exit code: $status" | |||
| echo "Please run 'dev/ty-check' to check the type errors." | |||
| exit 1 | |||
| fi | |||
| fi | |||
| if $web_modified; then | |||
| @@ -38,7 +38,7 @@ const Field: FC<Props> = ({ | |||
| <div className={cn(className, inline && 'flex w-full items-center justify-between')}> | |||
| <div | |||
| onClick={() => supportFold && toggleFold()} | |||
| className={cn('sticky top-0 z-10 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}> | |||
| className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}> | |||
| <div className='flex h-6 items-center'> | |||
| <div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}> | |||
| {title} {required && <span className='text-text-destructive'>*</span>} | |||