Преглед на файлове

Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

tags/2.0.0-beta.1
-LAN- преди 2 месеца
родител
ревизия
8c41d95d03
No account linked to committer's email address

+ 0
- 11
api/child_class.py Целия файл

@@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass


class ChildClass(ParentClass):
"""Test child class for module import helper tests"""

def __init__(self, name):
super().__init__(name)

def get_name(self):
return f"Child: {self.name}"

+ 1
- 1
api/controllers/console/app/workflow.py Целия файл

@@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
)

app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id

workflow_created_at = TimestampField().format(workflow.created_at)


+ 2
- 1
api/controllers/console/datasets/data_source.py Целия файл

@@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],

+ 6
- 3
api/controllers/console/datasets/datasets.py Целия файл

@@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
@@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
@@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],

+ 5
- 4
api/controllers/console/datasets/datasets_document.py Целия файл

@@ -40,6 +40,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")

extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
)

indexing_runner = IndexingRunner()
@@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")

extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)

elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

+ 6
- 3
api/core/indexing_runner.py Целия файл

@@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@@ -340,7 +341,9 @@ class IndexingRunner:

if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@@ -351,7 +354,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@@ -371,7 +374,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

+ 2
- 2
api/core/rag/extractor/extract_processor.py Целия файл

@@ -45,7 +45,7 @@ class ExtractProcessor:
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = "\n"
@@ -76,7 +76,7 @@ class ExtractProcessor:
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
if return_text:
delimiter = "\n"
return delimiter.join(

+ 1
- 1
api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py Целия файл

@@ -87,7 +87,7 @@ class ClickZettaVolumeConfig(BaseModel):
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
# 暂时禁用权限检查功能,直接设置为false
# Temporarily disable permission check feature, set directly to false
values.setdefault("permission_check", False)

# Validate required fields

+ 95
- 95
api/extensions/storage/clickzetta_volume/file_lifecycle.py Целия файл

@@ -1,7 +1,7 @@
"""ClickZetta Volume文件生命周期管理
"""ClickZetta Volume file lifecycle management

该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。
支持知识库文件的完整生命周期管理。
This module provides file lifecycle management features including version control, automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""

import json
@@ -15,17 +15,17 @@ logger = logging.getLogger(__name__)


class FileStatus(Enum):
"""文件状态枚举"""
"""File status enumeration"""

ACTIVE = "active" # 活跃状态
ARCHIVED = "archived" # 已归档
DELETED = "deleted" # 已删除(软删除)
BACKUP = "backup" # 备份文件
ACTIVE = "active" # Active status
ARCHIVED = "archived" # Archived
DELETED = "deleted" # Deleted (soft delete)
BACKUP = "backup" # Backup file


@dataclass
class FileMetadata:
"""文件元数据"""
"""File metadata"""

filename: str
size: int | None
@@ -38,7 +38,7 @@ class FileMetadata:
parent_version: Optional[int] = None

def to_dict(self) -> dict:
"""转换为字典格式"""
"""Convert to dictionary format"""
data = asdict(self)
data["created_at"] = self.created_at.isoformat()
data["modified_at"] = self.modified_at.isoformat()
@@ -47,7 +47,7 @@ class FileMetadata:

@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
"""从字典创建实例"""
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
@@ -56,14 +56,14 @@ class FileMetadata:


class FileLifecycleManager:
"""文件生命周期管理器"""
"""File lifecycle manager"""

def __init__(self, storage, dataset_id: Optional[str] = None):
"""初始化生命周期管理器
"""Initialize lifecycle manager

Args:
storage: ClickZetta Volume存储实例
dataset_id: 数据集ID(用于Table Volume)
storage: ClickZetta Volume storage instance
dataset_id: Dataset ID (for Table Volume)
"""
self._storage = storage
self._dataset_id = dataset_id
@@ -72,21 +72,21 @@ class FileLifecycleManager:
self._backup_prefix = ".backups/"
self._deleted_prefix = ".deleted/"

# 获取权限管理器(如果存在)
# Get permission manager (if exists)
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)

def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
"""保存文件并管理生命周期
"""Save file and manage lifecycle

Args:
filename: 文件名
data: 文件内容
tags: 文件标签
filename: File name
data: File content
tags: File tags

Returns:
文件元数据
File metadata
"""
# 权限检查
# Permission check
if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError

@@ -98,28 +98,28 @@ class FileLifecycleManager:
)

try:
# 1. 检查是否存在旧版本
# 1. Check if old version exists
metadata_dict = self._load_metadata()
current_metadata = metadata_dict.get(filename)

# 2. 如果存在旧版本,创建版本备份
# 2. If old version exists, create version backup
if current_metadata:
self._create_version_backup(filename, current_metadata)

# 3. 计算文件信息
# 3. Calculate file information
now = datetime.now()
checksum = self._calculate_checksum(data)
new_version = (current_metadata["version"] + 1) if current_metadata else 1

# 4. 保存新文件
# 4. Save new file
self._storage.save(filename, data)

# 5. 创建元数据
# 5. Create metadata
created_at = now
parent_version = None

if current_metadata:
# 如果created_at是字符串,转换为datetime
# If created_at is string, convert to datetime
if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata["created_at"])
else:
@@ -138,7 +138,7 @@ class FileLifecycleManager:
parent_version=parent_version,
)

# 6. 更新元数据
# 6. Update metadata
metadata_dict[filename] = file_metadata.to_dict()
self._save_metadata(metadata_dict)

@@ -150,13 +150,13 @@ class FileLifecycleManager:
raise

def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
"""获取文件元数据
"""Get file metadata

Args:
filename: 文件名
filename: File name

Returns:
文件元数据,如果不存在返回None
File metadata, returns None if not exists
"""
try:
metadata_dict = self._load_metadata()
@@ -168,37 +168,37 @@ class FileLifecycleManager:
return None

def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""列出文件的所有版本
"""List all versions of a file

Args:
filename: 文件名
filename: File name

Returns:
文件版本列表,按版本号排序
File version list, sorted by version number
"""
try:
versions = []

# 获取当前版本
# Get current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
versions.append(current_metadata)

# 获取历史版本
# Get historical versions
try:
version_files = self._storage.scan(self._dataset_id or "", files=True)
for file_path in version_files:
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
# 解析版本号
# Parse version number
version_str = file_path.split(".v")[-1].split(".")[0]
try:
version_num = int(version_str)
# 这里简化处理,实际应该从版本文件中读取元数据
# 暂时创建基本的元数据信息
# Simplified processing here, should actually read metadata from version file
# Temporarily create basic metadata information
except ValueError:
continue
except:
# 如果无法扫描版本文件,只返回当前版本
# If cannot scan version files, only return current version
pass

return sorted(versions, key=lambda x: x.version or 0, reverse=True)
@@ -208,32 +208,32 @@ class FileLifecycleManager:
return []

def restore_version(self, filename: str, version: int) -> bool:
"""恢复文件到指定版本
"""Restore file to specified version

Args:
filename: 文件名
version: 要恢复的版本号
filename: File name
version: Version number to restore

Returns:
恢复是否成功
Whether restore succeeded
"""
try:
version_filename = f"{self._version_prefix}{filename}.v{version}"

# 检查版本文件是否存在
# Check if version file exists
if not self._storage.exists(version_filename):
logger.warning("Version %s of %s not found", version, filename)
return False

# 读取版本文件内容
# Read version file content
version_data = self._storage.load_once(version_filename)

# 保存当前版本为备份
# Save current version as backup
current_metadata = self.get_file_metadata(filename)
if current_metadata:
self._create_version_backup(filename, current_metadata.to_dict())

# 恢复文件
# Restore file
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
return True

@@ -242,21 +242,21 @@ class FileLifecycleManager:
return False

def archive_file(self, filename: str) -> bool:
"""归档文件
"""Archive file

Args:
filename: 文件名
filename: File name

Returns:
归档是否成功
Whether archive succeeded
"""
# 权限检查
# Permission check
if not self._check_permission(filename, "archive"):
logger.warning("Permission denied for archive operation on file: %s", filename)
return False

try:
# 更新文件状态为归档
# Update file status to archived
metadata_dict = self._load_metadata()
if filename not in metadata_dict:
logger.warning("File %s not found in metadata", filename)
@@ -275,36 +275,36 @@ class FileLifecycleManager:
return False

def soft_delete_file(self, filename: str) -> bool:
"""软删除文件(移动到删除目录)
"""Soft delete file (move to deleted directory)

Args:
filename: 文件名
filename: File name

Returns:
删除是否成功
Whether delete succeeded
"""
# 权限检查
# Permission check
if not self._check_permission(filename, "delete"):
logger.warning("Permission denied for soft delete operation on file: %s", filename)
return False

try:
# 检查文件是否存在
# Check if file exists
if not self._storage.exists(filename):
logger.warning("File %s not found", filename)
return False

# 读取文件内容
# Read file content
file_data = self._storage.load_once(filename)

# 移动到删除目录
# Move to deleted directory
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self._storage.save(deleted_filename, file_data)

# 删除原文件
# Delete original file
self._storage.delete(filename)

# 更新元数据
# Update metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED.value
@@ -319,27 +319,27 @@ class FileLifecycleManager:
return False

def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
"""清理旧版本文件
"""Cleanup old version files

Args:
max_versions: 保留的最大版本数
max_age_days: 版本文件的最大保留天数
max_versions: Maximum number of versions to keep
max_age_days: Maximum retention days for version files

Returns:
清理的文件数量
Number of files cleaned
"""
try:
cleaned_count = 0

# 获取所有版本文件
# Get all version files
try:
all_files = self._storage.scan(self._dataset_id or "", files=True)
version_files = [f for f in all_files if f.startswith(self._version_prefix)]

# 按文件分组
# Group by file
file_versions: dict[str, list[tuple[int, str]]] = {}
for version_file in version_files:
# 解析文件名和版本
# Parse filename and version
parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2:
base_filename = parts[0]
@@ -352,12 +352,12 @@ class FileLifecycleManager:
except ValueError:
continue

# 清理每个文件的旧版本
# Cleanup old versions for each file
for base_filename, versions in file_versions.items():
# 按版本号排序
# Sort by version number
versions.sort(key=lambda x: x[0], reverse=True)

# 保留最新的max_versions个版本,删除其余的
# Keep the newest max_versions versions, delete the rest
if len(versions) > max_versions:
to_delete = versions[max_versions:]
for version_num, version_file in to_delete:
@@ -377,10 +377,10 @@ class FileLifecycleManager:
return 0

def get_storage_statistics(self) -> dict[str, Any]:
"""获取存储统计信息
"""Get storage statistics

Returns:
存储统计字典
Storage statistics dictionary
"""
try:
metadata_dict = self._load_metadata()
@@ -402,7 +402,7 @@ class FileLifecycleManager:
for filename, metadata in metadata_dict.items():
file_meta = FileMetadata.from_dict(metadata)

# 统计文件状态
# Count file status
if file_meta.status == FileStatus.ACTIVE:
stats["active_files"] = (stats["active_files"] or 0) + 1
elif file_meta.status == FileStatus.ARCHIVED:
@@ -410,13 +410,13 @@ class FileLifecycleManager:
elif file_meta.status == FileStatus.DELETED:
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1

# 统计大小
# Count size
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)

# 统计版本
# Count versions
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)

# 找出最新和最旧的文件
# Find newest and oldest files
if oldest_date is None or file_meta.created_at < oldest_date:
oldest_date = file_meta.created_at
stats["oldest_file"] = filename
@@ -432,12 +432,12 @@ class FileLifecycleManager:
return {}

def _create_version_backup(self, filename: str, metadata: dict):
"""创建版本备份"""
"""Create version backup"""
try:
# 读取当前文件内容
# Read current file content
current_data = self._storage.load_once(filename)

# 保存为版本文件
# Save as version file
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
self._storage.save(version_filename, current_data)

@@ -447,7 +447,7 @@ class FileLifecycleManager:
logger.warning("Failed to create version backup for %s: %s", filename, e)

def _load_metadata(self) -> dict[str, Any]:
"""加载元数据文件"""
"""Load metadata file"""
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
@@ -460,7 +460,7 @@ class FileLifecycleManager:
return {}

def _save_metadata(self, metadata_dict: dict):
"""保存元数据文件"""
"""Save metadata file"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
@@ -470,45 +470,45 @@ class FileLifecycleManager:
raise

def _calculate_checksum(self, data: bytes) -> str:
"""计算文件校验和"""
"""Calculate file checksum"""
import hashlib

return hashlib.md5(data).hexdigest()

def _check_permission(self, filename: str, operation: str) -> bool:
"""检查文件操作权限
"""Check file operation permission

Args:
filename: 文件名
operation: 操作类型
filename: File name
operation: Operation type

Returns:
True if permission granted, False otherwise
"""
# 如果没有权限管理器,默认允许
# If no permission manager, allow by default
if not self._permission_manager:
return True

try:
# 根据操作类型映射到权限
# Map operation type to permission
operation_mapping = {
"save": "save",
"load": "load_once",
"delete": "delete",
"archive": "delete", # 归档需要删除权限
"restore": "save", # 恢复需要写权限
"cleanup": "delete", # 清理需要删除权限
"archive": "delete", # Archive requires delete permission
"restore": "save", # Restore requires write permission
"cleanup": "delete", # Cleanup requires delete permission
"read": "load_once",
"write": "save",
}

mapped_operation = operation_mapping.get(operation, operation)

# 检查权限
# Check permission
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
return bool(result)

except Exception as e:
logger.exception("Permission check failed for %s operation %s", filename, operation)
# 安全默认:权限检查失败时拒绝访问
# Safe default: deny access when permission check fails
return False

+ 110
- 109
api/extensions/storage/clickzetta_volume/volume_permissions.py Целия файл

@@ -1,7 +1,7 @@
"""ClickZetta Volume权限管理机制
"""ClickZetta Volume permission management mechanism

该模块提供Volume权限检查、验证和管理功能。
根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。
This module provides Volume permission checking, validation and management features.
According to ClickZetta's permission model, different Volume types have different permission requirements.
"""

import logging
@@ -12,29 +12,29 @@ logger = logging.getLogger(__name__)


class VolumePermission(Enum):
"""Volume权限类型枚举"""
"""Volume permission type enumeration"""

READ = "SELECT" # 对应ClickZetta的SELECT权限
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
LIST = "SELECT" # 列出文件需要SELECT权限
DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限
USAGE = "USAGE" # External Volume需要的基本权限
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions
LIST = "SELECT" # Listing files requires SELECT permission
DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions
USAGE = "USAGE" # Basic permission required for External Volume


class VolumePermissionManager:
"""Volume权限管理器"""
"""Volume permission manager"""

def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
"""初始化权限管理器
"""Initialize permission manager

Args:
connection_or_config: ClickZetta连接对象或配置字典
volume_type: Volume类型 (user|table|external)
volume_name: Volume名称 (用于external volume)
connection_or_config: ClickZetta connection object or configuration dictionary
volume_type: Volume type (user|table|external)
volume_name: Volume name (for external volume)
"""
# 支持两种初始化方式:连接对象或配置字典
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# 从配置字典创建连接
# Create connection from configuration dictionary
import clickzetta # type: ignore[import-untyped]

config = connection_or_config
@@ -50,7 +50,7 @@ class VolumePermissionManager:
self._volume_type = config.get("volume_type", volume_type)
self._volume_name = config.get("volume_name", volume_name)
else:
# 直接使用连接对象
# Use connection object directly
self._connection = connection_or_config
self._volume_type = volume_type
self._volume_name = volume_name
@@ -61,14 +61,14 @@ class VolumePermissionManager:
raise ValueError("volume_type is required")

self._permission_cache: dict[str, set[str]] = {}
self._current_username = None # 将从连接中获取当前用户名
self._current_username = None # Will get current username from connection

def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
"""检查用户是否有执行特定操作的权限
"""Check if user has permission to perform specific operation

Args:
operation: 要执行的操作类型
dataset_id: 数据集ID (用于table volume)
operation: Type of operation to perform
dataset_id: Dataset ID (for table volume)

Returns:
True if user has permission, False otherwise
@@ -89,20 +89,20 @@ class VolumePermissionManager:
return False

def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
"""检查User Volume权限
"""Check User Volume permission

User Volume权限规则:
- 用户对自己的User Volume有全部权限
- 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限
- 更注重连接身份验证,而不是复杂的权限检查
User Volume permission rules:
- User has full permissions on their own User Volume
- As long as user can connect to ClickZetta, they have basic User Volume permissions by default
- Focus more on connection authentication rather than complex permission checking
"""
try:
# 获取当前用户名
# Get current username
current_user = self._get_current_username()

# 检查基本连接状态
# Check basic connection status
with self._connection.cursor() as cursor:
# 简单的连接测试,如果能执行查询说明用户有基本权限
# Simple connection test, if query can be executed user has basic permissions
cursor.execute("SELECT 1")
result = cursor.fetchone()

@@ -121,17 +121,18 @@ class VolumePermissionManager:

except Exception as e:
logger.exception("User Volume permission check failed")
# 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示
# For User Volume, if permission check fails, it might be a configuration issue,
# provide friendlier error message
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
return False

def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
"""检查Table Volume权限
"""Check Table Volume permission

Table Volume权限规则:
- Table Volume权限继承对应表的权限
- SELECT权限 -> 可以READ/LIST文件
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
Table Volume permission rules:
- Table Volume permissions inherit from corresponding table permissions
- SELECT permission -> can READ/LIST files
- INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
"""
if not dataset_id:
logger.warning("dataset_id is required for table volume permission check")
@@ -140,11 +141,11 @@ class VolumePermissionManager:
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id

try:
# 检查表权限
# Check table permissions
permissions = self._get_table_permissions(table_name)
required_permissions = set(operation.value.split(","))

# 检查是否有所需的所有权限
# Check if has all required permissions
has_permission = required_permissions.issubset(permissions)

logger.debug(
@@ -163,22 +164,22 @@ class VolumePermissionManager:
return False

def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
"""检查External Volume权限
"""Check External Volume permission

External Volume权限规则:
- 尝试获取对External Volume的权限
- 如果权限检查失败,进行备选验证
- 对于开发环境,提供更宽松的权限检查
External Volume permission rules:
- Try to get permissions for External Volume
- If permission check fails, perform fallback verification
- For development environment, provide more lenient permission checking
"""
if not self._volume_name:
logger.warning("volume_name is required for external volume permission check")
return False

try:
# 检查External Volume权限
# Check External Volume permissions
permissions = self._get_external_volume_permissions(self._volume_name)

# External Volume权限映射:根据操作类型确定所需权限
# External Volume permission mapping: determine required permissions based on operation type
required_permissions = set()

if operation in [VolumePermission.READ, VolumePermission.LIST]:
@@ -186,7 +187,7 @@ class VolumePermissionManager:
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
required_permissions.add("write")

# 检查是否有所需的所有权限
# Check if has all required permissions
has_permission = required_permissions.issubset(permissions)

logger.debug(
@@ -198,11 +199,11 @@ class VolumePermissionManager:
has_permission,
)

# 如果权限检查失败,尝试备选验证
# If permission check fails, try fallback verification
if not has_permission:
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)

# 备选验证:尝试列出Volume来验证基本访问权限
# Fallback verification: try listing Volume to verify basic access permissions
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
@@ -222,13 +223,13 @@ class VolumePermissionManager:
return False

def _get_table_permissions(self, table_name: str) -> set[str]:
"""获取用户对指定表的权限
"""Get user permissions for specified table

Args:
table_name: 表名
table_name: Table name

Returns:
用户对该表的权限集合
Set of user permissions for this table
"""
cache_key = f"table:{table_name}"

@@ -239,18 +240,18 @@ class VolumePermissionManager:

try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限
# Use correct ClickZetta syntax to check current user permissions
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()

# 解析权限结果,查找对该表的权限
# Parse permission results, find permissions for this table
for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""
object_name = grant[2] if len(grant) > 2 else ""

# 检查是否是对该表的权限
# Check if it's permission for this table
if (
object_type == "TABLE"
and object_name == table_name
@@ -263,7 +264,7 @@ class VolumePermissionManager:
else:
permissions.add(privilege)

# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
# If no explicit permissions found, try executing a simple query to verify permissions
if not permissions:
try:
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
@@ -273,15 +274,15 @@ class VolumePermissionManager:

except Exception as e:
logger.warning("Could not check table permissions for %s: %s", table_name, e)
# 安全默认:权限检查失败时拒绝访问
# Safe default: deny access when permission check fails
pass

# 缓存权限信息
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions

def _get_current_username(self) -> str:
"""获取当前用户名"""
"""Get current username"""
if self._current_username:
return self._current_username

@@ -298,7 +299,7 @@ class VolumePermissionManager:
return "unknown"

def _get_user_permissions(self, username: str) -> set[str]:
"""获取用户的基本权限集合"""
"""Get user's basic permission set"""
cache_key = f"user_permissions:{username}"

if cache_key in self._permission_cache:
@@ -308,17 +309,17 @@ class VolumePermissionManager:

try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查当前用户权限
# Use correct ClickZetta syntax to check current user permissions
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()

# 解析权限结果,查找用户的基本权限
# Parse permission results, find user's basic permissions
for grant in grants:
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
privilege = grant[0].upper()
object_type = grant[1].upper() if len(grant) > 1 else ""

# 收集所有相关权限
# Collect all relevant permissions
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
if privilege == "ALL":
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
@@ -327,21 +328,21 @@ class VolumePermissionManager:

except Exception as e:
logger.warning("Could not check user permissions for %s: %s", username, e)
# 安全默认:权限检查失败时拒绝访问
# Safe default: deny access when permission check fails
pass

# 缓存权限信息
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions

def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
"""获取用户对指定External Volume的权限
"""Get user permissions for specified External Volume

Args:
volume_name: External Volume名称
volume_name: External Volume name

Returns:
用户对该Volume的权限集合
Set of user permissions for this Volume
"""
cache_key = f"external_volume:{volume_name}"

@@ -352,15 +353,15 @@ class VolumePermissionManager:

try:
with self._connection.cursor() as cursor:
# 使用正确的ClickZetta语法检查Volume权限
# Use correct ClickZetta syntax to check Volume permissions
logger.info("Checking permissions for volume: %s", volume_name)
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
grants = cursor.fetchall()

logger.info("Raw grants result for %s: %s", volume_name, grants)

# 解析权限结果
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# Parse permission results
# Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
# grantee_name, grantor_name, grant_option, granted_time)
for grant in grants:
logger.info("Processing grant: %s", grant)
@@ -378,7 +379,7 @@ class VolumePermissionManager:
object_name,
)

# 检查是否是对该Volume的权限或者是层级权限
# Check if it's permission for this Volume or hierarchical permission
if (
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
@@ -399,14 +400,14 @@ class VolumePermissionManager:

logger.info("Final permissions for %s: %s", volume_name, permissions)

# 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
# If no explicit permissions found, try viewing Volume list to verify basic permissions
if not permissions:
try:
cursor.execute("SHOW VOLUMES")
volumes = cursor.fetchall()
for volume in volumes:
if len(volume) > 0 and volume[0] == volume_name:
permissions.add("read") # 至少有读权限
permissions.add("read") # At least has read permission
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
break
except Exception:
@@ -414,7 +415,7 @@ class VolumePermissionManager:

except Exception as e:
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
# 在权限检查失败时,尝试基本的Volume访问验证
# When permission check fails, try basic Volume access verification
try:
with self._connection.cursor() as cursor:
cursor.execute("SHOW VOLUMES")
@@ -423,30 +424,30 @@ class VolumePermissionManager:
if len(volume) > 0 and volume[0] == volume_name:
logger.info("Basic volume access verified for %s", volume_name)
permissions.add("read")
permissions.add("write") # 假设有写权限
permissions.add("write") # Assume has write permission
break
except Exception as basic_e:
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
# 最后的备选方案:假设有基本权限
# Last fallback: assume basic permissions
permissions.add("read")

# 缓存权限信息
# Cache permission information
self._permission_cache[cache_key] = permissions
return permissions

def clear_permission_cache(self):
"""清空权限缓存"""
"""Clear permission cache"""
self._permission_cache.clear()
logger.debug("Permission cache cleared")

def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
"""获取权限摘要
"""Get permission summary

Args:
dataset_id: 数据集ID (用于table volume)
dataset_id: Dataset ID (for table volume)

Returns:
权限摘要字典
Permission summary dictionary
"""
summary = {}

@@ -456,43 +457,43 @@ class VolumePermissionManager:
return summary

def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
"""检查文件路径的权限继承
"""Check permission inheritance for file path

Args:
file_path: 文件路径
operation: 要执行的操作
file_path: File path
operation: Operation to perform

Returns:
True if user has permission, False otherwise
"""
try:
# 解析文件路径
# Parse file path
path_parts = file_path.strip("/").split("/")

if not path_parts:
logger.warning("Invalid file path for permission inheritance check")
return False

# 对于Table Volume,第一层是dataset_id
# For Table Volume, first layer is dataset_id
if self._volume_type == "table":
if len(path_parts) < 1:
return False

dataset_id = path_parts[0]

# 检查对dataset的权限
# Check permissions for dataset
has_dataset_permission = self.check_permission(operation, dataset_id)

if not has_dataset_permission:
logger.debug("Permission denied for dataset %s", dataset_id)
return False

# 检查路径遍历攻击
# Check path traversal attack
if self._contains_path_traversal(file_path):
logger.warning("Path traversal attack detected: %s", file_path)
return False

# 检查是否访问敏感目录
# Check if accessing sensitive directory
if self._is_sensitive_path(file_path):
logger.warning("Access to sensitive path denied: %s", file_path)
return False
@@ -501,20 +502,20 @@ class VolumePermissionManager:
return True

elif self._volume_type == "user":
# User Volume的权限继承
# User Volume permission inheritance
current_user = self._get_current_username()

# 检查是否试图访问其他用户的目录
# Check if attempting to access other user's directory
if len(path_parts) > 1 and path_parts[0] != current_user:
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
return False

# 检查基本权限
# Check basic permissions
return self.check_permission(operation)

elif self._volume_type == "external":
# External Volume的权限继承
# 检查对External Volume的权限
# External Volume permission inheritance
# Check permissions for External Volume
return self.check_permission(operation)

else:
@@ -526,8 +527,8 @@ class VolumePermissionManager:
return False

def _contains_path_traversal(self, file_path: str) -> bool:
"""检查路径是否包含路径遍历攻击"""
# 检查常见的路径遍历模式
"""Check if path contains path traversal attack"""
# Check common path traversal patterns
traversal_patterns = [
"../",
"..\\",
@@ -547,18 +548,18 @@ class VolumePermissionManager:
if pattern in file_path_lower:
return True

# 检查绝对路径
# Check absolute path
if file_path.startswith("/") or file_path.startswith("\\"):
return True

# 检查Windows驱动器路径
# Check Windows drive path
if len(file_path) >= 2 and file_path[1] == ":":
return True

return False

def _is_sensitive_path(self, file_path: str) -> bool:
"""检查路径是否为敏感路径"""
"""Check if path is sensitive path"""
sensitive_patterns = [
"passwd",
"shadow",
@@ -582,11 +583,11 @@ class VolumePermissionManager:
return any(pattern in file_path_lower for pattern in sensitive_patterns)

def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
"""验证操作权限
"""Validate operation permission

Args:
operation: 操作名称 (save|load|exists|delete|scan)
dataset_id: 数据集ID
operation: Operation name (save|load|exists|delete|scan)
dataset_id: Dataset ID

Returns:
True if operation is allowed, False otherwise
@@ -611,7 +612,7 @@ class VolumePermissionManager:


class VolumePermissionError(Exception):
"""Volume权限错误异常"""
"""Volume permission error exception"""

def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
self.operation = operation
@@ -623,15 +624,15 @@ class VolumePermissionError(Exception):
def check_volume_permission(
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
) -> None:
"""权限检查装饰器函数
"""Permission check decorator function

Args:
permission_manager: 权限管理器
operation: 操作名称
dataset_id: 数据集ID
permission_manager: Permission manager
operation: Operation name
dataset_id: Dataset ID

Raises:
VolumePermissionError: 如果没有权限
VolumePermissionError: If no permission
"""
if not permission_manager.validate_operation(operation, dataset_id):
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"

+ 0
- 11
api/lazy_load_class.py Целия файл

@@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass


class LazyLoadChildClass(ParentClass):
"""Test lazy load child class for module import helper tests"""

def __init__(self, name):
super().__init__(name)

def get_name(self):
return self.name

+ 716
- 0
api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py Целия файл

@@ -0,0 +1,716 @@
import json
from unittest.mock import patch

import pytest
from faker import Faker

from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService


class TestWorkflowToolManageService:
"""Integration tests for WorkflowToolManageService using testcontainers."""

@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch(
"services.tools.workflow_tools_manage_service.WorkflowToolProviderController"
) as mock_workflow_tool_provider_controller,
patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
patch("services.tools.workflow_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
):
# Setup default mock returns for app service
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None

# Setup default mock returns for account service
mock_account_feature_service.get_system_features.return_value.is_allow_register = True

# Mock ModelManager for model configuration
mock_model_instance = mock_model_manager.return_value
mock_model_instance.get_default_model_instance.return_value = None
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")

# Mock WorkflowToolProviderController
mock_workflow_tool_provider_controller.from_db.return_value = None

# Mock ToolLabelManager
mock_tool_label_manager.update_tool_labels.return_value = None

# Mock ToolTransformService
mock_tool_transform_service.workflow_provider_to_controller.return_value = None

yield {
"feature_service": mock_feature_service,
"enterprise_service": mock_enterprise_service,
"model_manager": mock_model_manager,
"account_feature_service": mock_account_feature_service,
"workflow_tool_provider_controller": mock_workflow_tool_provider_controller,
"tool_label_manager": mock_tool_label_manager,
"tool_transform_service": mock_tool_transform_service,
}

def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.

Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies

Returns:
tuple: (app, account, workflow) - Created app, account and workflow instances
"""
fake = Faker()

# Setup mocks for account creation
mock_external_service_dependencies[
"account_feature_service"
].get_system_features.return_value.is_allow_register = True

# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant

# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "workflow",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}

app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

# Create workflow for the app
workflow = WorkflowModel(
tenant_id=tenant.id,
app_id=app.id,
type="workflow",
version="1.0.0",
graph=json.dumps({}),
features=json.dumps({}),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
)

from extensions.ext_database import db

db.session.add(workflow)
db.session.commit()

# Update app to reference the workflow
app.workflow_id = workflow.id
db.session.commit()

return app, account, workflow

def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters."""
return [
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
},
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
},
]

def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful workflow tool creation with valid parameters.

This test verifies:
- Proper workflow tool creation with all required fields
- Correct database state after creation
- Proper relationship establishment
- External service integration
- Return value correctness
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Setup workflow tool creation parameters
tool_name = fake.word()
tool_label = fake.word()
tool_icon = {"type": "emoji", "emoji": "🔧"}
tool_description = fake.text(max_nb_chars=200)
tool_parameters = self._create_test_workflow_tool_parameters()
tool_privacy_policy = fake.text(max_nb_chars=100)
tool_labels = ["automation", "workflow"]

# Execute the method under test
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=tool_label,
icon=tool_icon,
description=tool_description,
parameters=tool_parameters,
privacy_policy=tool_privacy_policy,
labels=tool_labels,
)

# Verify the result
assert result == {"result": "success"}

# Verify database state
from extensions.ext_database import db

# Check if workflow tool provider was created
created_tool_provider = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
)
.first()
)

assert created_tool_provider is not None
assert created_tool_provider.name == tool_name
assert created_tool_provider.label == tool_label
assert created_tool_provider.icon == json.dumps(tool_icon)
assert created_tool_provider.description == tool_description
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
assert created_tool_provider.privacy_policy == tool_privacy_policy
assert created_tool_provider.version == workflow.version
assert created_tool_provider.user_id == account.id
assert created_tool_provider.tenant_id == account.current_tenant.id
assert created_tool_provider.app_id == app.id

# Verify external service calls
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called_once()
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
mock_external_service_dependencies[
"tool_transform_service"
].workflow_provider_to_controller.assert_called_once()

def test_create_workflow_tool_duplicate_name_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when name already exists.

This test verifies:
- Proper error handling for duplicate tool names
- Database constraint enforcement
- Correct error message
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Create first workflow tool
first_tool_name = fake.word()
first_tool_parameters = self._create_test_workflow_tool_parameters()

WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=first_tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=first_tool_parameters,
)

# Attempt to create second workflow tool with same name
second_tool_parameters = self._create_test_workflow_tool_parameters()

with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=first_tool_name, # Same name
label=fake.word(),
icon={"type": "emoji", "emoji": "⚙️"},
description=fake.text(max_nb_chars=200),
parameters=second_tool_parameters,
)

# Verify error message
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)

# Verify only one tool was created
from extensions.ext_database import db

tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 1

def test_create_workflow_tool_invalid_app_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app does not exist.

This test verifies:
- Proper error handling for non-existent apps
- Correct error message
- No database changes when app is invalid
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Generate non-existent app ID
non_existent_app_id = fake.uuid4()

# Attempt to create workflow tool with non-existent app
tool_parameters = self._create_test_workflow_tool_parameters()

with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=non_existent_app_id, # Non-existent app ID
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=tool_parameters,
)

# Verify error message
assert f"App {non_existent_app_id} not found" in str(exc_info.value)

# Verify no workflow tool was created
from extensions.ext_database import db

tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 0

def test_create_workflow_tool_invalid_parameters_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when parameters are invalid.

This test verifies:
- Proper error handling for invalid parameter configurations
- Parameter validation enforcement
- Correct error message
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Setup invalid workflow tool parameters (missing required fields)
invalid_parameters = [
{
"name": "input_text",
# Missing description and form fields
"type": "string",
"required": True,
}
]

# Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=invalid_parameters,
)

# Verify error message contains validation error
assert "validation error" in str(exc_info.value).lower()

# Verify no workflow tool was created
from extensions.ext_database import db

tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 0

def test_create_workflow_tool_duplicate_app_id_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app_id already exists.

This test verifies:
- Proper error handling for duplicate app_id
- Database constraint enforcement for app_id uniqueness
- Correct error message
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Create first workflow tool
first_tool_name = fake.word()
first_tool_parameters = self._create_test_workflow_tool_parameters()

WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=first_tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=first_tool_parameters,
)

# Attempt to create second workflow tool with same app_id but different name
second_tool_name = fake.word()
second_tool_parameters = self._create_test_workflow_tool_parameters()

with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id, # Same app_id
name=second_tool_name, # Different name
label=fake.word(),
icon={"type": "emoji", "emoji": "⚙️"},
description=fake.text(max_nb_chars=200),
parameters=second_tool_parameters,
)

# Verify error message
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)

# Verify only one tool was created
from extensions.ext_database import db

tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 1

def test_create_workflow_tool_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app has no workflow.

This test verifies:
- Proper error handling for apps without workflows
- Correct error message
- No database changes when workflow is missing
"""
fake = Faker()

# Create test data but without workflow
app, account, _ = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Remove workflow reference from app
from extensions.ext_database import db

app.workflow_id = None
db.session.commit()

# Attempt to create workflow tool for app without workflow
tool_parameters = self._create_test_workflow_tool_parameters()

with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=tool_parameters,
)

# Verify error message
assert f"Workflow not found for app {app.id}" in str(exc_info.value)

# Verify no workflow tool was created
tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 0

def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful workflow tool update with valid parameters.

This test verifies:
- Proper workflow tool update with all required fields
- Correct database state after update
- Proper relationship maintenance
- External service integration
- Return value correctness
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Create initial workflow tool
initial_tool_name = fake.word()
initial_tool_parameters = self._create_test_workflow_tool_parameters()

WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=initial_tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=initial_tool_parameters,
)

# Get the created tool
from extensions.ext_database import db

created_tool = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
)
.first()
)

# Setup update parameters
updated_tool_name = fake.word()
updated_tool_label = fake.word()
updated_tool_icon = {"type": "emoji", "emoji": "⚙️"}
updated_tool_description = fake.text(max_nb_chars=200)
updated_tool_parameters = self._create_test_workflow_tool_parameters()
updated_tool_privacy_policy = fake.text(max_nb_chars=100)
updated_tool_labels = ["automation", "updated"]

# Execute the update method
result = WorkflowToolManageService.update_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_tool_id=created_tool.id,
name=updated_tool_name,
label=updated_tool_label,
icon=updated_tool_icon,
description=updated_tool_description,
parameters=updated_tool_parameters,
privacy_policy=updated_tool_privacy_policy,
labels=updated_tool_labels,
)

# Verify the result
assert result == {"result": "success"}

# Verify database state was updated
db.session.refresh(created_tool)
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
assert created_tool.icon == json.dumps(updated_tool_icon)
assert created_tool.description == updated_tool_description
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
assert created_tool.privacy_policy == updated_tool_privacy_policy
assert created_tool.version == workflow.version
assert created_tool.updated_at is not None

# Verify external service calls
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called()
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called()
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()

def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test workflow tool update fails when tool does not exist.

This test verifies:
- Proper error handling for non-existent tools
- Correct error message
- No database changes when tool is invalid
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Generate non-existent tool ID
non_existent_tool_id = fake.uuid4()

# Attempt to update non-existent workflow tool
tool_parameters = self._create_test_workflow_tool_parameters()

with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.update_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_tool_id=non_existent_tool_id, # Non-existent tool ID
name=fake.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=tool_parameters,
)

# Verify error message
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)

# Verify no workflow tool was created
from extensions.ext_database import db

tool_count = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
.count()
)

assert tool_count == 0

def test_update_workflow_tool_same_name_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test workflow tool update succeeds when keeping the same name.

This test verifies:
- Proper handling when updating tool with same name
- Database state maintenance
- Update timestamp is set
"""
fake = Faker()

# Create test data
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)

# Create first workflow tool
first_tool_name = fake.word()
first_tool_parameters = self._create_test_workflow_tool_parameters()

WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=first_tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=first_tool_parameters,
)

# Get the created tool
from extensions.ext_database import db

created_tool = (
db.session.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
)
.first()
)

# Attempt to update tool with same name (should not fail)
result = WorkflowToolManageService.update_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_tool_id=created_tool.id,
name=first_tool_name, # Same name
label=fake.word(),
icon={"type": "emoji", "emoji": "⚙️"},
description=fake.text(max_nb_chars=200),
parameters=first_tool_parameters,
)

# Verify update was successful
assert result == {"result": "success"}

# Verify tool still exists with the same name
db.session.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None

+ 0
- 6
dev/reformat Целия файл

@@ -14,11 +14,5 @@ uv run --directory api --dev ruff format ./
# run dotenv-linter linter
uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example

# run import-linter
uv run --directory api --dev lint-imports

# run ty check
dev/ty-check

# run mypy check
dev/mypy-check

+ 0
- 9
web/.husky/pre-commit Целия файл

@@ -41,15 +41,6 @@ if $api_modified; then
echo "Please run 'dev/reformat' to fix the fixable linting errors."
exit 1
fi

# run ty checks
uv run --directory api --dev ty check || status=$?
status=${status:-0}
if [ $status -ne 0 ]; then
echo "ty type checker on api module error, exit code: $status"
echo "Please run 'dev/ty-check' to check the type errors."
exit 1
fi
fi

if $web_modified; then

+ 1
- 1
web/app/components/workflow/nodes/_base/components/field.tsx Целия файл

@@ -38,7 +38,7 @@ const Field: FC<Props> = ({
<div className={cn(className, inline && 'flex w-full items-center justify-between')}>
<div
onClick={() => supportFold && toggleFold()}
className={cn('sticky top-0 z-10 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
<div className='flex h-6 items-center'>
<div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}>
{title} {required && <span className='text-text-destructive'>*</span>}

Loading…
Отказ
Запис