### What problem does this PR solve?

This PR introduces Pydantic-based validation for the list datasets HTTP API, improving code clarity and robustness. Key changes include:

- Pydantic validation
- Error handling
- Test updates
- Documentation updates

### Type of change

- [x] Documentation Update
- [x] Refactoring
| @@ -32,12 +32,22 @@ from api.utils.api_utils import ( | |||
| deep_merge, | |||
| get_error_argument_result, | |||
| get_error_data_result, | |||
| get_error_operating_result, | |||
| get_error_permission_result, | |||
| get_parser_config, | |||
| get_result, | |||
| remap_dictionary_keys, | |||
| token_required, | |||
| verify_embedding_availability, | |||
| ) | |||
| from api.utils.validation_utils import CreateDatasetReq, DeleteDatasetReq, UpdateDatasetReq, validate_and_parse_json_request | |||
| from api.utils.validation_utils import ( | |||
| CreateDatasetReq, | |||
| DeleteDatasetReq, | |||
| ListDatasetReq, | |||
| UpdateDatasetReq, | |||
| validate_and_parse_json_request, | |||
| validate_and_parse_request_args, | |||
| ) | |||
| @manager.route("/datasets", methods=["POST"]) # noqa: F821 | |||
| @@ -113,7 +123,7 @@ def create(tenant_id): | |||
| try: | |||
| if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") | |||
| return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists") | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| @@ -126,7 +136,7 @@ def create(tenant_id): | |||
| try: | |||
| ok, t = TenantService.get_by_id(tenant_id) | |||
| if not ok: | |||
| return get_error_data_result(message="Tenant not found") | |||
| return get_error_permission_result(message="Tenant not found") | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| @@ -153,16 +163,7 @@ def create(tenant_id): | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| response_data = {} | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "doc_num": "document_count", | |||
| "parser_id": "chunk_method", | |||
| "embd_id": "embedding_model", | |||
| } | |||
| for key, value in k.to_dict().items(): | |||
| new_key = key_mapping.get(key, key) | |||
| response_data[new_key] = value | |||
| response_data = remap_dictionary_keys(k.to_dict()) | |||
| return get_result(data=response_data) | |||
| @@ -232,7 +233,7 @@ def delete(tenant_id): | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| if len(error_kb_ids) > 0: | |||
| return get_error_data_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") | |||
| return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") | |||
| errors = [] | |||
| success_count = 0 | |||
| @@ -347,7 +348,7 @@ def update(tenant_id, dataset_id): | |||
| try: | |||
| kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) | |||
| if kb is None: | |||
| return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") | |||
| return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| @@ -418,7 +419,7 @@ def list_datasets(tenant_id): | |||
| name: page_size | |||
| type: integer | |||
| required: false | |||
| default: 1024 | |||
| default: 30 | |||
| description: Number of items per page. | |||
| - in: query | |||
| name: orderby | |||
| @@ -445,47 +446,46 @@ def list_datasets(tenant_id): | |||
| items: | |||
| type: object | |||
| """ | |||
| id = request.args.get("id") | |||
| name = request.args.get("name") | |||
| if id: | |||
| kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id) | |||
| args, err = validate_and_parse_request_args(request, ListDatasetReq) | |||
| if err is not None: | |||
| return get_error_argument_result(err) | |||
| kb_id = request.args.get("id") | |||
| name = args.get("name") | |||
| if kb_id: | |||
| try: | |||
| kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| if not kbs: | |||
| return get_error_data_result(f"You don't own the dataset {id}") | |||
| return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") | |||
| if name: | |||
| kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) | |||
| try: | |||
| kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| if not kbs: | |||
| return get_error_data_result(f"You don't own the dataset {name}") | |||
| page_number = int(request.args.get("page", 1)) | |||
| items_per_page = int(request.args.get("page_size", 30)) | |||
| orderby = request.args.get("orderby", "create_time") | |||
| if request.args.get("desc", "false").lower() not in ["true", "false"]: | |||
| return get_error_data_result("desc should be true or false") | |||
| if request.args.get("desc", "true").lower() == "false": | |||
| desc = False | |||
| else: | |||
| desc = True | |||
| tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) | |||
| kbs = KnowledgebaseService.get_list( | |||
| [m["tenant_id"] for m in tenants], | |||
| tenant_id, | |||
| page_number, | |||
| items_per_page, | |||
| orderby, | |||
| desc, | |||
| id, | |||
| name, | |||
| ) | |||
| renamed_list = [] | |||
| return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") | |||
| try: | |||
| tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) | |||
| kbs = KnowledgebaseService.get_list( | |||
| [m["tenant_id"] for m in tenants], | |||
| tenant_id, | |||
| args["page"], | |||
| args["page_size"], | |||
| args["orderby"], | |||
| args["desc"], | |||
| kb_id, | |||
| name, | |||
| ) | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| response_data_list = [] | |||
| for kb in kbs: | |||
| key_mapping = { | |||
| "chunk_num": "chunk_count", | |||
| "doc_num": "document_count", | |||
| "parser_id": "chunk_method", | |||
| "embd_id": "embedding_model", | |||
| } | |||
| renamed_data = {} | |||
| for key, value in kb.items(): | |||
| new_key = key_mapping.get(key, key) | |||
| renamed_data[new_key] = value | |||
| renamed_list.append(renamed_data) | |||
| return get_result(data=renamed_list) | |||
| response_data_list.append(remap_dictionary_keys(kb)) | |||
| return get_result(data=response_data_list) | |||
| @@ -329,6 +329,14 @@ def get_error_argument_result(message="Invalid arguments"): | |||
| return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message) | |||
| def get_error_permission_result(message="Permission error"): | |||
| return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message) | |||
| def get_error_operating_result(message="Operating error"): | |||
| return get_result(code=settings.RetCode.OPERATING_ERROR, message=message) | |||
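For reference, each new helper maps a failure mode to a distinct `settings.RetCode`. A minimal sketch of the values involved — the numbers are inferred from the updated test assertions in this PR (101/103/108), not read from `settings` itself:

```python
# Hypothetical reconstruction of the relevant RetCode members; numeric
# values inferred from this PR's test assertions, not from settings.
from enum import IntEnum

class RetCode(IntEnum):
    SUCCESS = 0
    ARGUMENT_ERROR = 101    # get_error_argument_result (validation failures)
    DATA_ERROR = 102        # get_error_data_result (database failures)
    OPERATING_ERROR = 103   # get_error_operating_result (e.g. duplicate name)
    PERMISSION_ERROR = 108  # get_error_permission_result (tenant lacks access)
```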
| def generate_confirmation_token(tenant_id): | |||
| serializer = URLSafeTimedSerializer(tenant_id) | |||
| return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] | |||
| @@ -514,3 +522,38 @@ def deep_merge(default: dict, custom: dict) -> dict: | |||
| base_dict[key] = val | |||
| return merged | |||
| def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict: | |||
| """ | |||
| Transform dictionary keys using a configurable mapping schema. | |||
| Args: | |||
| source_data: Original dictionary to process | |||
| key_aliases: Custom key transformation rules (Optional) | |||
| When provided, overrides default key mapping | |||
| Format: {<original_key>: <new_key>, ...} | |||
| Returns: | |||
| dict: New dictionary with transformed keys preserving original values | |||
| Example: | |||
| >>> input_data = {"old_key": "value", "another_field": 42} | |||
| >>> remap_dictionary_keys(input_data, {"old_key": "new_key"}) | |||
| {'new_key': 'value', 'another_field': 42} | |||
| """ | |||
| DEFAULT_KEY_MAP = { | |||
| "chunk_num": "chunk_count", | |||
| "doc_num": "document_count", | |||
| "parser_id": "chunk_method", | |||
| "embd_id": "embedding_model", | |||
| } | |||
| transformed_data = {} | |||
| mapping = key_aliases or DEFAULT_KEY_MAP | |||
| for original_key, value in source_data.items(): | |||
| mapped_key = mapping.get(original_key, original_key) | |||
| transformed_data[mapped_key] = value | |||
| return transformed_data | |||
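A minimal usage sketch (field values illustrative), matching how `create` and `list_datasets` now call the helper:

```python
# Illustrative input mirroring Knowledgebase.to_dict(); unmapped keys
# such as "id" pass through unchanged.
record = {"id": "kb1", "chunk_num": 7, "doc_num": 2, "parser_id": "naive", "embd_id": "bge-m3"}
print(remap_dictionary_keys(record))
# {'id': 'kb1', 'chunk_count': 7, 'document_count': 2,
#  'chunk_method': 'naive', 'embedding_model': 'bge-m3'}
```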
| @@ -13,13 +13,13 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import uuid | |||
| from collections import Counter | |||
| from enum import auto | |||
| from typing import Annotated, Any | |||
| from uuid import UUID | |||
| from flask import Request | |||
| from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator | |||
| from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator | |||
| from pydantic_core import PydanticCustomError | |||
| from strenum import StrEnum | |||
| from werkzeug.exceptions import BadRequest, UnsupportedMediaType | |||
| @@ -102,6 +102,71 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] | |||
| return parsed_payload, None | |||
| def validate_and_parse_request_args(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None) -> tuple[dict[str, Any] | None, str | None]: | |||
| """ | |||
| Validates and parses request arguments against a Pydantic model. | |||
| This function performs a complete request validation workflow: | |||
| 1. Extracts query parameters from the request | |||
| 2. Merges with optional extra values (if provided) | |||
| 3. Validates against the specified Pydantic model | |||
| 4. Cleans the output by removing extra values | |||
| 5. Returns either parsed data or an error message | |||
| Args: | |||
| request (Request): Web framework request object containing query parameters | |||
| validator (type[BaseModel]): Pydantic model class for validation | |||
| extras (dict[str, Any] | None): Optional additional values to include in validation | |||
| but exclude from final output. Defaults to None. | |||
| Returns: | |||
| tuple[dict[str, Any] | None, str | None]: | |||
| - First element: Validated/parsed arguments as dict if successful, None otherwise | |||
| - Second element: Formatted error message if validation failed, None otherwise | |||
| Behavior: | |||
| - Query parameters are merged with extras before validation | |||
| - Extras are automatically removed from the final output | |||
| - All validation errors are formatted into a human-readable string | |||
| Raises: | |||
| TypeError: If validator is not a Pydantic BaseModel subclass | |||
| Examples: | |||
| Successful validation: | |||
| >>> validate_and_parse_request_args(request, MyValidator) | |||
| ({'param1': 'value'}, None) | |||
| Failed validation: | |||
| >>> validate_and_parse_request_args(request, MyValidator) | |||
| (None, "param1: Field required") | |||
| With extras: | |||
| >>> validate_and_parse_request_args(request, MyValidator, extras={'internal_id': 123}) | |||
| ({'param1': 'value'}, None) # internal_id removed from output | |||
| Notes: | |||
| - Uses request.args.to_dict() for Flask-compatible parameter extraction | |||
| - Maintains immutability of original request arguments | |||
| - Preserves type conversion from Pydantic validation | |||
| """ | |||
| args = request.args.to_dict(flat=True) | |||
| try: | |||
| if extras is not None: | |||
| args.update(extras) | |||
| validated_args = validator(**args) | |||
| except ValidationError as e: | |||
| return None, format_validation_error_message(e) | |||
| parsed_args = validated_args.model_dump() | |||
| if extras is not None: | |||
| for key in list(parsed_args.keys()): | |||
| if key in extras: | |||
| del parsed_args[key] | |||
| return parsed_args, None | |||
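A runnable sketch of the intended call pattern, using a hypothetical `DemoArgs` validator and Flask's test request context (neither is part of this PR):

```python
from flask import Flask, request
from pydantic import BaseModel, Field

from api.utils.validation_utils import validate_and_parse_request_args  # repo import, assumed

class DemoArgs(BaseModel):  # hypothetical validator for illustration
    page: int = Field(default=1, ge=1)

app = Flask(__name__)
with app.test_request_context("/datasets?page=2"):
    args, err = validate_and_parse_request_args(request, DemoArgs)
    print(args, err)  # {'page': 2} None

with app.test_request_context("/datasets?page=0"):
    _, err = validate_and_parse_request_args(request, DemoArgs)
    print(err)  # roughly: "page: Input should be greater than or equal to 1"
```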
| def format_validation_error_message(e: ValidationError) -> str: | |||
| """ | |||
| Formats validation errors into a standardized string format. | |||
| @@ -143,6 +208,105 @@ def format_validation_error_message(e: ValidationError) -> str: | |||
| return "\n".join(error_messages) | |||
| def normalize_str(v: Any) -> Any: | |||
| """ | |||
| Normalizes string values to a standard format while preserving non-string inputs. | |||
| Performs the following transformations when input is a string: | |||
| 1. Trims leading/trailing whitespace (str.strip()) | |||
| 2. Converts to lowercase (str.lower()) | |||
| Non-string inputs are returned unchanged, making this function safe for mixed-type | |||
| processing pipelines. | |||
| Args: | |||
| v (Any): Input value to normalize. Accepts any Python object. | |||
| Returns: | |||
| Any: Normalized string if input was string-type, original value otherwise. | |||
| Behavior Examples: | |||
| String Input: " Admin " → "admin" | |||
| Empty String: " " → "" (empty string) | |||
| Non-String: | |||
| - 123 → 123 | |||
| - None → None | |||
| - ["User"] → ["User"] | |||
| Typical Use Cases: | |||
| - Standardizing user input | |||
| - Preparing data for case-insensitive comparison | |||
| - Cleaning API parameters | |||
| - Normalizing configuration values | |||
| Edge Cases: | |||
| - Unicode whitespace is handled by str.strip() | |||
| - Locale-independent lowercasing (str.lower()) | |||
| - Preserves falsy values (0, False, etc.) | |||
| Example: | |||
| >>> normalize_str(" ReadOnly ") | |||
| 'readonly' | |||
| >>> normalize_str(42) | |||
| 42 | |||
| """ | |||
| if isinstance(v, str): | |||
| stripped = v.strip() | |||
| normalized = stripped.lower() | |||
| return normalized | |||
| return v | |||
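A quick behavior check (assumes the function is importable):

```python
assert normalize_str("  TEAM ") == "team"   # trimmed and lowercased
assert normalize_str(42) == 42              # non-strings pass through
assert normalize_str(None) is None
```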
| def validate_uuid1_hex(v: Any) -> str: | |||
| """ | |||
| Validates and converts input to a UUID version 1 hexadecimal string. | |||
| This function performs strict validation and normalization: | |||
| 1. Accepts either UUID objects or UUID-formatted strings | |||
| 2. Verifies the UUID is version 1 (time-based) | |||
| 3. Returns the 32-character hexadecimal representation | |||
| Args: | |||
| v (Any): Input value to validate. Can be: | |||
| - UUID object (must be version 1) | |||
| - String in UUID format (e.g. "550e8400-e29b-41d4-a716-446655440000") | |||
| Returns: | |||
| str: 32-character lowercase hexadecimal string without hyphens | |||
| Example: "550e8400e29b41d4a716446655440000" | |||
| Raises: | |||
| PydanticCustomError: With code "invalid_UUID1_format" when: | |||
| - Input is not a UUID object or valid UUID string | |||
| - UUID version is not 1 | |||
| - String doesn't match UUID format | |||
| Examples: | |||
| Valid cases: | |||
| >>> validate_uuid1_hex("550e8400-e29b-41d4-a716-446655440000") | |||
| '550e8400e29b41d4a716446655440000' | |||
| >>> validate_uuid1_hex(UUID('550e8400-e29b-41d4-a716-446655440000')) | |||
| '550e8400e29b41d4a716446655440000' | |||
| Invalid cases: | |||
| >>> validate_uuid1_hex("not-a-uuid") # raises PydanticCustomError | |||
| >>> validate_uuid1_hex(12345) # raises PydanticCustomError | |||
| >>> validate_uuid1_hex(UUID(int=0)) # v4, raises PydanticCustomError | |||
| Notes: | |||
| - Uses Python's built-in UUID parser for format validation | |||
| - Version check prevents accidental use of other UUID versions | |||
| - Hyphens in input strings are automatically removed in output | |||
| """ | |||
| try: | |||
| uuid_obj = UUID(v) if isinstance(v, str) else v | |||
| if uuid_obj.version != 1: | |||
| raise PydanticCustomError("invalid_UUID1_format", "Must be a UUID1 format") | |||
| return uuid_obj.hex | |||
| except (AttributeError, ValueError, TypeError): | |||
| raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format") | |||
| class PermissionEnum(StrEnum): | |||
| me = auto() | |||
| team = auto() | |||
| @@ -217,8 +381,8 @@ class CreateDatasetReq(Base): | |||
| avatar: str | None = Field(default=None, max_length=65535) | |||
| description: str | None = Field(default=None, max_length=65535) | |||
| embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] | |||
| permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] | |||
| chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] | |||
| permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) | |||
| chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") | |||
| pagerank: int = Field(default=0, ge=0, le=100) | |||
| parser_config: ParserConfig | None = Field(default=None) | |||
| @@ -315,22 +479,8 @@ class CreateDatasetReq(Base): | |||
| @field_validator("permission", mode="before") | |||
| @classmethod | |||
| def permission_auto_lowercase(cls, v: Any) -> Any: | |||
| """ | |||
| Normalize permission input to lowercase for consistent PermissionEnum matching. | |||
| Args: | |||
| v (Any): Raw input value for the permission field | |||
| Returns: | |||
| Lowercase string if input is string type, otherwise returns original value | |||
| Behavior: | |||
| - Converts string inputs to lowercase (e.g., "ME" → "me") | |||
| - Non-string values pass through unchanged | |||
| - Works in validation pre-processing stage (before enum conversion) | |||
| """ | |||
| return v.lower() if isinstance(v, str) else v | |||
| def normalize_permission(cls, v: Any) -> Any: | |||
| return normalize_str(v) | |||
| @field_validator("parser_config", mode="before") | |||
| @classmethod | |||
| @@ -387,93 +537,117 @@ class CreateDatasetReq(Base): | |||
| class UpdateDatasetReq(CreateDatasetReq): | |||
| dataset_id: UUID1 = Field(...) | |||
| dataset_id: str = Field(...) | |||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] | |||
| @field_serializer("dataset_id") | |||
| def serialize_uuid_to_hex(self, v: uuid.UUID) -> str: | |||
| """ | |||
| Serializes a UUID version 1 object to its hexadecimal string representation. | |||
| This field serializer specifically handles UUID version 1 objects, converting them | |||
| to their canonical 32-character hexadecimal format without hyphens. The conversion | |||
| is designed for consistent serialization in API responses and database storage. | |||
| Args: | |||
| v (uuid.UUID1): The UUID version 1 object to serialize. Must be a valid | |||
| UUID1 instance generated by Python's uuid module. | |||
| Returns: | |||
| str: 32-character lowercase hexadecimal string representation | |||
| Example: "550e8400e29b41d4a716446655440000" | |||
| Raises: | |||
| AttributeError: If input is not a proper UUID object (missing hex attribute) | |||
| TypeError: If input is not a UUID1 instance (when type checking is enabled) | |||
| Notes: | |||
| - Version 1 UUIDs contain timestamp and MAC address information | |||
| - The .hex property automatically converts to lowercase hexadecimal | |||
| - For cross-version compatibility, consider typing as uuid.UUID instead | |||
| """ | |||
| return v.hex | |||
| @field_validator("dataset_id", mode="before") | |||
| @classmethod | |||
| def validate_dataset_id(cls, v: Any) -> str: | |||
| return validate_uuid1_hex(v) | |||
| class DeleteReq(Base): | |||
| ids: list[UUID1] | None = Field(...) | |||
| ids: list[str] | None = Field(...) | |||
| @field_validator("ids", mode="after") | |||
| def check_duplicate_ids(cls, v: list[UUID1] | None) -> list[str] | None: | |||
| @classmethod | |||
| def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: | |||
| """ | |||
| Validates and converts a list of UUID1 objects to hexadecimal strings while checking for duplicates. | |||
| This validator implements a three-stage processing pipeline: | |||
| 1. Null Handling - returns None for empty/null input | |||
| 2. UUID Conversion - transforms UUID objects to hex strings | |||
| 3. Duplicate Validation - ensures all IDs are unique | |||
| Validates and normalizes a list of UUID strings with None handling. | |||
| Behavior Specifications: | |||
| - Input: None → Returns None (indicates no operation) | |||
| - Input: [] → Returns [] (empty list for explicit no-op) | |||
| - Input: [UUID1,...] → Returns validated hex strings | |||
| - Duplicates: Raises formatted PydanticCustomError | |||
| This post-processing validator performs: | |||
| 1. None input handling (pass-through) | |||
| 2. UUID version 1 validation for each list item | |||
| 3. Duplicate value detection | |||
| 4. Returns normalized UUID hex strings or None | |||
| Args: | |||
| v (list[UUID1] | None): | |||
| - None: Indicates no datasets should be processed | |||
| - Empty list: Explicit empty operation | |||
| - Populated list: Dataset UUIDs to validate/convert | |||
| v_list (list[str] | None): Input list that has passed initial validation. | |||
| Either a list of UUID strings or None. | |||
| Returns: | |||
| list[str] | None: | |||
| - None when input is None | |||
| - List of 32-character hex strings (lowercase, no hyphens) | |||
| Example: ["550e8400e29b41d4a716446655440000"] | |||
| - None if input was None | |||
| - List of normalized UUID hex strings otherwise: | |||
| * 32-character lowercase | |||
| * Valid UUID version 1 | |||
| * Unique within list | |||
| Raises: | |||
| PydanticCustomError: When duplicates detected, containing: | |||
| - Error type: "duplicate_uuids" | |||
| - Template message: "Duplicate ids: '{duplicate_ids}'" | |||
| - Context: {"duplicate_ids": "id1, id2, ..."} | |||
| PydanticCustomError: With structured error details when: | |||
| - "invalid_UUID1_format": Any string fails UUIDv1 validation | |||
| - "duplicate_uuids": If duplicate IDs are detected | |||
| Example: | |||
| >>> validate([UUID("..."), UUID("...")]) | |||
| ["2cdf0456e9a711ee8000000000000000", ...] | |||
| Validation Rules: | |||
| - None input returns None | |||
| - Empty list returns empty list | |||
| - All non-None items must be valid UUIDv1 | |||
| - No duplicates permitted | |||
| - Original order preserved | |||
| >>> validate([UUID("..."), UUID("...")]) # Duplicates | |||
| PydanticCustomError: Duplicate ids: '2cdf0456e9a711ee8000000000000000' | |||
| Examples: | |||
| Valid cases: | |||
| >>> validate_ids(None) | |||
| None | |||
| >>> validate_ids([]) | |||
| [] | |||
| >>> validate_ids(["550e8400-e29b-41d4-a716-446655440000"]) | |||
| ["550e8400e29b41d4a716446655440000"] | |||
| Invalid cases: | |||
| >>> validate_ids(["invalid"]) | |||
| # raises PydanticCustomError(invalid_UUID1_format) | |||
| >>> validate_ids(["550e...", "550e..."]) | |||
| # raises PydanticCustomError(duplicate_uuids) | |||
| Notes: | |||
| - Version check rejects non-time-based UUIDs | |||
| - Duplicate detection avoids redundant delete operations | |||
| - None handling preserves the endpoint's no-op semantics | |||
| """ | |||
| if not v: | |||
| return v | |||
| if v_list is None: | |||
| return None | |||
| uuid_hex_list = [ids.hex for ids in v] | |||
| duplicates = [item for item, count in Counter(uuid_hex_list).items() if count > 1] | |||
| ids_list = [] | |||
| for v in v_list: | |||
| try: | |||
| ids_list.append(validate_uuid1_hex(v)) | |||
| except PydanticCustomError as e: | |||
| raise e | |||
| duplicates = [item for item, count in Counter(ids_list).items() if count > 1] | |||
| if duplicates: | |||
| duplicates_str = ", ".join(duplicates) | |||
| raise PydanticCustomError("duplicate_uuids", "Duplicate ids: '{duplicate_ids}'", {"duplicate_ids": duplicates_str}) | |||
| return uuid_hex_list | |||
| return ids_list | |||
| class DeleteDatasetReq(DeleteReq): ... | |||
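A sketch of the model end to end (assumes it is importable); duplicate ids surface as a `ValidationError`:

```python
import uuid
from pydantic import ValidationError

print(DeleteDatasetReq(ids=[str(uuid.uuid1())]).ids)  # ['<32-char hex>']

dup = uuid.uuid1().hex
try:
    DeleteDatasetReq(ids=[dup, dup])
except ValidationError as e:
    print(e)  # duplicate_uuids: Duplicate ids: '<hex>'
```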
| class OrderByEnum(StrEnum): | |||
| create_time = auto() | |||
| update_time = auto() | |||
| class BaseListReq(Base): | |||
| id: str | None = None | |||
| name: str | None = None | |||
| page: int = Field(default=1, ge=1) | |||
| page_size: int = Field(default=30, ge=1) | |||
| orderby: OrderByEnum = Field(default=OrderByEnum.create_time) | |||
| desc: bool = Field(default=True) | |||
| @field_validator("id", mode="before") | |||
| @classmethod | |||
| def validate_id(cls, v: Any) -> str: | |||
| return validate_uuid1_hex(v) | |||
| @field_validator("orderby", mode="before") | |||
| @classmethod | |||
| def normalize_orderby(cls, v: Any) -> Any: | |||
| return normalize_str(v) | |||
| class ListDatasetReq(BaseListReq): ... | |||
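A sketch of query-string coercion through the new model (raw values mirror what `request.args.to_dict()` would yield); `id` and `name` default to None when omitted:

```python
raw = {"page": "2", "page_size": "10", "orderby": " CREATE_TIME ", "desc": "false"}
req = ListDatasetReq(**raw)
assert req.page == 2 and req.page_size == 10  # string-to-int coercion
assert req.orderby == "create_time"           # normalized from " CREATE_TIME "
assert req.desc is False                      # lax bool parsing of "false"
```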
| @@ -122,7 +122,7 @@ class TestDatasetCreate: | |||
| assert res["code"] == 0, res | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 103, res | |||
| assert res["message"] == f"Dataset name '{name}' already exists", res | |||
| @pytest.mark.p3 | |||
| @@ -134,7 +134,7 @@ class TestDatasetCreate: | |||
| payload = {"name": name.lower()} | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 103, res | |||
| assert res["message"] == f"Dataset name '{name.lower()}' already exists", res | |||
| @pytest.mark.p2 | |||
| @@ -296,14 +296,15 @@ class TestDatasetCreate: | |||
| ("team", "team"), | |||
| ("me_upercase", "ME"), | |||
| ("team_upercase", "TEAM"), | |||
| ("whitespace", " ME "), | |||
| ], | |||
| ids=["me", "team", "me_upercase", "team_upercase"], | |||
| ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], | |||
| ) | |||
| def test_permission(self, get_http_api_auth, name, permission): | |||
| payload = {"name": name, "permission": permission} | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["permission"] == permission.lower(), res | |||
| assert res["data"]["permission"] == permission.lower().strip(), res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| @@ -40,8 +41,8 @@ class TestAuthorization: | |||
| ) | |||
| def test_auth_invalid(self, auth, expected_code, expected_message): | |||
| res = delete_datasets(auth) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestRquest: | |||
| @@ -140,17 +141,25 @@ class TestDatasetsDelete: | |||
| payload = {"ids": ["not_uuid"]} | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 101, res | |||
| assert "Input should be a valid UUID" in res["message"], res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| res = list_datasets(get_http_api_auth) | |||
| assert len(res["data"]) == 1, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.usefixtures("add_dataset_func") | |||
| def test_id_not_uuid1(self, get_http_api_auth): | |||
| payload = {"ids": [uuid.uuid4().hex]} | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 101, res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.usefixtures("add_dataset_func") | |||
| def test_id_wrong_uuid(self, get_http_api_auth): | |||
| payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| res = list_datasets(get_http_api_auth) | |||
| @@ -170,7 +179,7 @@ class TestDatasetsDelete: | |||
| if callable(func): | |||
| payload = func(dataset_ids) | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| res = list_datasets(get_http_api_auth) | |||
| @@ -195,7 +204,7 @@ class TestDatasetsDelete: | |||
| assert res["code"] == 0, res | |||
| res = delete_datasets(get_http_api_auth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| @pytest.mark.p2 | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| @@ -21,8 +22,8 @@ from libs.auth import RAGFlowHttpApiAuth | |||
| from libs.utils import is_sorted | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "auth, expected_code, expected_message", | |||
| [ | |||
| @@ -34,269 +35,305 @@ class TestAuthorization: | |||
| ), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||
| def test_auth_invalid(self, auth, expected_code, expected_message): | |||
| res = list_datasets(auth) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestCapability: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, get_http_api_auth): | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses), responses | |||
| @pytest.mark.usefixtures("add_datasets") | |||
| class TestDatasetsList: | |||
| @pytest.mark.p1 | |||
| def test_default(self, get_http_api_auth): | |||
| res = list_datasets(get_http_api_auth, params={}) | |||
| def test_params_unset(self, get_http_api_auth): | |||
| res = list_datasets(get_http_api_auth, None) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 5 | |||
| @pytest.mark.p2 | |||
| def test_params_empty(self, get_http_api_auth): | |||
| res = list_datasets(get_http_api_auth, {}) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_page_size, expected_message", | |||
| "params, expected_page_size", | |||
| [ | |||
| ({"page": None, "page_size": 2}, 0, 2, ""), | |||
| ({"page": 0, "page_size": 2}, 0, 2, ""), | |||
| ({"page": 2, "page_size": 2}, 0, 2, ""), | |||
| ({"page": 3, "page_size": 2}, 0, 1, ""), | |||
| ({"page": "3", "page_size": 2}, 0, 1, ""), | |||
| pytest.param( | |||
| {"page": -1, "page_size": 2}, | |||
| 100, | |||
| 0, | |||
| "1064", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| pytest.param( | |||
| {"page": "a", "page_size": 2}, | |||
| 100, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: \'a\'")""", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ({"page": 2, "page_size": 2}, 2), | |||
| ({"page": 3, "page_size": 2}, 1), | |||
| ({"page": 4, "page_size": 2}, 0), | |||
| ({"page": "2", "page_size": 2}, 2), | |||
| ({"page": 1, "page_size": 10}, 5), | |||
| ], | |||
| ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], | |||
| ) | |||
| def test_page(self, get_http_api_auth, params, expected_page_size): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == expected_page_size, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_message", | |||
| [ | |||
| ({"page": 0}, 101, "Input should be greater than or equal to 1"), | |||
| ({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), | |||
| ], | |||
| ids=["page_0", "page_a"], | |||
| ) | |||
| def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message): | |||
| def test_page_invalid(self, get_http_api_auth, params, expected_code, expected_message): | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_page_size | |||
| else: | |||
| assert res["message"] == expected_message | |||
| assert res["code"] == expected_code, res | |||
| assert expected_message in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_page_none(self, get_http_api_auth): | |||
| params = {"page": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_page_size, expected_message", | |||
| "params, expected_page_size", | |||
| [ | |||
| ({"page_size": None}, 0, 5, ""), | |||
| ({"page_size": 0}, 0, 0, ""), | |||
| ({"page_size": 1}, 0, 1, ""), | |||
| ({"page_size": 6}, 0, 5, ""), | |||
| ({"page_size": "1"}, 0, 1, ""), | |||
| pytest.param( | |||
| {"page_size": -1}, | |||
| 100, | |||
| 0, | |||
| "1064", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| pytest.param( | |||
| {"page_size": "a"}, | |||
| 100, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: \'a\'")""", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ({"page_size": 1}, 1), | |||
| ({"page_size": 3}, 3), | |||
| ({"page_size": 5}, 5), | |||
| ({"page_size": 6}, 5), | |||
| ({"page_size": "1"}, 1), | |||
| ], | |||
| ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], | |||
| ) | |||
| def test_page_size( | |||
| self, | |||
| get_http_api_auth, | |||
| params, | |||
| expected_code, | |||
| expected_page_size, | |||
| expected_message, | |||
| ): | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_page_size | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_page_size(self, get_http_api_auth, params, expected_page_size): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == expected_page_size, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_message", | |||
| [ | |||
| ({"page_size": 0}, 101, "Input should be greater than or equal to 1"), | |||
| ({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"), | |||
| ], | |||
| ) | |||
| def test_page_size_invalid(self, get_http_api_auth, params, expected_code, expected_message): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == expected_code, res | |||
| assert expected_message in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_page_size_none(self, get_http_api_auth): | |||
| params = {"page_size": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, assertions", | |||
| [ | |||
| ({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), | |||
| ({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))), | |||
| ({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "update_time", True))), | |||
| ], | |||
| ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"], | |||
| ) | |||
| def test_orderby(self, get_http_api_auth, params, assertions): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| if callable(assertions): | |||
| assert assertions(res), res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, assertions, expected_message", | |||
| "params", | |||
| [ | |||
| ({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""), | |||
| pytest.param( | |||
| {"orderby": "name", "desc": "False"}, | |||
| 0, | |||
| lambda r: (is_sorted(r["data"]["docs"], "name", False)), | |||
| "", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| pytest.param( | |||
| {"orderby": "unknown"}, | |||
| 102, | |||
| 0, | |||
| "orderby should be create_time or update_time", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| {"orderby": ""}, | |||
| {"orderby": "unknown"}, | |||
| ], | |||
| ids=["empty", "unknown"], | |||
| ) | |||
| def test_orderby( | |||
| self, | |||
| get_http_api_auth, | |||
| params, | |||
| expected_code, | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| assert assertions(res) | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_orderby_invalid(self, get_http_api_auth, params): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Input should be 'create_time' or 'update_time'" in res["message"], res | |||
| @pytest.mark.p3 | |||
| def test_orderby_none(self, get_http_api_auth): | |||
| params = {"order_by": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert is_sorted(res["data"], "create_time", True), res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, assertions, expected_message", | |||
| "params, assertions", | |||
| [ | |||
| ({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""), | |||
| ({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), | |||
| ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), | |||
| ({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""), | |||
| ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""), | |||
| pytest.param( | |||
| {"desc": "unknown"}, | |||
| 102, | |||
| 0, | |||
| "desc should be true or false", | |||
| marks=pytest.mark.skip(reason="issues/5851"), | |||
| ), | |||
| ({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))), | |||
| ({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))), | |||
| ({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))), | |||
| ({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))), | |||
| ({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))), | |||
| ({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))), | |||
| ], | |||
| ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"], | |||
| ) | |||
| def test_desc( | |||
| self, | |||
| get_http_api_auth, | |||
| params, | |||
| expected_code, | |||
| assertions, | |||
| expected_message, | |||
| ): | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if callable(assertions): | |||
| assert assertions(res) | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_desc(self, get_http_api_auth, params, assertions): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| if callable(assertions): | |||
| assert assertions(res), res | |||
| @pytest.mark.p1 | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_num, expected_message", | |||
| "params", | |||
| [ | |||
| ({"name": None}, 0, 5, ""), | |||
| ({"name": ""}, 0, 5, ""), | |||
| ({"name": "dataset_1"}, 0, 1, ""), | |||
| ({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"), | |||
| {"desc": 3.14}, | |||
| {"desc": "unknown"}, | |||
| ], | |||
| ids=["empty", "unknown"], | |||
| ) | |||
| def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message): | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if params["name"] in [None, ""]: | |||
| assert len(res["data"]) == expected_num | |||
| else: | |||
| assert res["data"][0]["name"] == params["name"] | |||
| else: | |||
| assert res["message"] == expected_message | |||
| def test_desc_invalid(self, get_http_api_auth, params): | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Input should be a valid boolean, unable to interpret input" in res["message"], res | |||
| @pytest.mark.p3 | |||
| def test_desc_none(self, get_http_api_auth): | |||
| params = {"desc": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert is_sorted(res["data"], "create_time", True), res | |||
| @pytest.mark.p1 | |||
| def test_name(self, get_http_api_auth): | |||
| params = {"name": "dataset_1"} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 1, res | |||
| assert res["data"][0]["name"] == "dataset_1", res | |||
| @pytest.mark.p2 | |||
| def test_name_wrong(self, get_http_api_auth): | |||
| params = {"name": "wrong name"} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_name_empty(self, get_http_api_auth): | |||
| params = {"name": ""} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p2 | |||
| def test_name_none(self, get_http_api_auth): | |||
| params = {"name": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p1 | |||
| def test_id(self, get_http_api_auth, add_datasets): | |||
| dataset_ids = add_datasets | |||
| params = {"id": dataset_ids[0]} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 1 | |||
| assert res["data"][0]["id"] == dataset_ids[0] | |||
| @pytest.mark.p2 | |||
| def test_id_not_uuid(self, get_http_api_auth): | |||
| params = {"id": "not_uuid"} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_id_not_uuid1(self, get_http_api_auth): | |||
| params = {"id": uuid.uuid4().hex} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_id_wrong_uuid(self, get_http_api_auth): | |||
| params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_id_empty(self, get_http_api_auth): | |||
| params = {"id": ""} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p2 | |||
| def test_id_none(self, get_http_api_auth): | |||
| params = {"id": None} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == 5, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "dataset_id, expected_code, expected_num, expected_message", | |||
| "func, name, expected_num", | |||
| [ | |||
| (None, 0, 5, ""), | |||
| ("", 0, 5, ""), | |||
| (lambda r: r[0], 0, 1, ""), | |||
| ("unknown", 102, 0, "You don't own the dataset unknown"), | |||
| (lambda r: r[0], "dataset_0", 1), | |||
| (lambda r: r[0], "dataset_1", 0), | |||
| ], | |||
| ids=["name_and_id_match", "name_and_id_mismatch"], | |||
| ) | |||
| def test_id( | |||
| self, | |||
| get_http_api_auth, | |||
| add_datasets, | |||
| dataset_id, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| def test_name_and_id(self, get_http_api_auth, add_datasets, func, name, expected_num): | |||
| dataset_ids = add_datasets | |||
| if callable(dataset_id): | |||
| params = {"id": dataset_id(dataset_ids)} | |||
| else: | |||
| params = {"id": dataset_id} | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| if params["id"] in [None, ""]: | |||
| assert len(res["data"]) == expected_num | |||
| else: | |||
| assert res["data"][0]["id"] == params["id"] | |||
| else: | |||
| assert res["message"] == expected_message | |||
| if callable(func): | |||
| params = {"id": func(dataset_ids), "name": name} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]) == expected_num, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "dataset_id, name, expected_code, expected_num, expected_message", | |||
| "dataset_id, name", | |||
| [ | |||
| (lambda r: r[0], "dataset_0", 0, 1, ""), | |||
| (lambda r: r[0], "dataset_1", 0, 0, ""), | |||
| (lambda r: r[0], "unknown", 102, 0, "You don't own the dataset unknown"), | |||
| ("id", "dataset_0", 102, 0, "You don't own the dataset id"), | |||
| (lambda r: r[0], "wrong_name"), | |||
| (uuid.uuid1().hex, "dataset_0"), | |||
| ], | |||
| ids=["name", "id"], | |||
| ) | |||
| def test_name_and_id( | |||
| self, | |||
| get_http_api_auth, | |||
| add_datasets, | |||
| dataset_id, | |||
| name, | |||
| expected_code, | |||
| expected_num, | |||
| expected_message, | |||
| ): | |||
| def test_name_and_id_wrong(self, get_http_api_auth, add_datasets, dataset_id, name): | |||
| dataset_ids = add_datasets | |||
| if callable(dataset_id): | |||
| params = {"id": dataset_id(dataset_ids), "name": name} | |||
| else: | |||
| params = {"id": dataset_id, "name": name} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == expected_code | |||
| if expected_code == 0: | |||
| assert len(res["data"]) == expected_num | |||
| else: | |||
| assert res["message"] == expected_message | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, get_http_api_auth): | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| @pytest.mark.p3 | |||
| def test_invalid_params(self, get_http_api_auth): | |||
| params = {"a": "b"} | |||
| res = list_datasets(get_http_api_auth, params=params) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]) == 5 | |||
| @pytest.mark.p2 | |||
| def test_field_unsupported(self, get_http_api_auth): | |||
| params = {"unknown_field": "unknown_field"} | |||
| res = list_datasets(get_http_api_auth, params) | |||
| assert res["code"] == 101, res | |||
| assert "Extra inputs are not permitted" in res["message"], res | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| import pytest | |||
| @@ -98,16 +99,23 @@ class TestCapability: | |||
| class TestDatasetUpdate: | |||
| @pytest.mark.p3 | |||
| def test_dataset_id_not_uuid(self, get_http_api_auth): | |||
| payload = {"name": "not_uuid"} | |||
| payload = {"name": "not uuid"} | |||
| res = update_dataset(get_http_api_auth, "not_uuid", payload) | |||
| assert res["code"] == 101, res | |||
| assert "Input should be a valid UUID" in res["message"], res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p3 | |||
| def test_dataset_id_not_uuid1(self, get_http_api_auth): | |||
| payload = {"name": "not uuid1"} | |||
| res = update_dataset(get_http_api_auth, uuid.uuid4().hex, payload) | |||
| assert res["code"] == 101, res | |||
| assert "Invalid UUID1 format" in res["message"], res | |||
| @pytest.mark.p3 | |||
| def test_dataset_id_wrong_uuid(self, get_http_api_auth): | |||
| payload = {"name": "wrong_uuid"} | |||
| payload = {"name": "wrong uuid"} | |||
| res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) | |||
| assert res["code"] == 102, res | |||
| assert res["code"] == 108, res | |||
| assert "lacks permission for dataset" in res["message"], res | |||
| @pytest.mark.p1 | |||
| @@ -322,8 +330,9 @@ class TestDatasetUpdate: | |||
| "team", | |||
| "ME", | |||
| "TEAM", | |||
| " ME ", | |||
| ], | |||
| ids=["me", "team", "me_upercase", "team_upercase"], | |||
| ids=["me", "team", "me_upercase", "team_upercase", "whitespace"], | |||
| ) | |||
| def test_permission(self, get_http_api_auth, add_dataset_func, permission): | |||
| dataset_id = add_dataset_func | |||
| @@ -333,7 +342,7 @@ class TestDatasetUpdate: | |||
| res = list_datasets(get_http_api_auth) | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["permission"] == permission.lower(), res | |||
| assert res["data"][0]["permission"] == permission.lower().strip(), res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| @@ -734,7 +743,6 @@ class TestDatasetUpdate: | |||
| assert res["code"] == 0, res | |||
| res = list_datasets(get_http_api_auth) | |||
| print(res) | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res | |||
| @@ -757,7 +765,6 @@ class TestDatasetUpdate: | |||
| assert res["code"] == 0, res | |||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||
| print(res) | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res | |||