### What problem does this PR solve? This PR introduces Pydantic-based validation for the update dataset HTTP API, improving code clarity and robustness. Key changes include: 1. Pydantic Validation 2. Error Handling 3. Test Updates 4. Documentation Updates 5. fix bug: #5915 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Documentation Update - [x] Refactoringtags/v0.19.0
| from api.db.services.file2document_service import File2DocumentService | from api.db.services.file2document_service import File2DocumentService | ||||
| from api.db.services.file_service import FileService | from api.db.services.file_service import FileService | ||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||||
| from api.db.services.user_service import TenantService | from api.db.services.user_service import TenantService | ||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from api.utils.api_utils import ( | from api.utils.api_utils import ( | ||||
| check_duplicate_ids, | check_duplicate_ids, | ||||
| dataset_readonly_fields, | |||||
| deep_merge, | |||||
| get_error_argument_result, | get_error_argument_result, | ||||
| get_error_data_result, | get_error_data_result, | ||||
| get_parser_config, | get_parser_config, | ||||
| get_result, | get_result, | ||||
| token_required, | token_required, | ||||
| valid, | |||||
| valid_parser_config, | |||||
| verify_embedding_availability, | verify_embedding_availability, | ||||
| ) | ) | ||||
| from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request | |||||
| from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request | |||||
| @manager.route("/datasets", methods=["POST"]) # noqa: F821 | @manager.route("/datasets", methods=["POST"]) # noqa: F821 | ||||
| return get_error_argument_result(err) | return get_error_argument_result(err) | ||||
| try: | try: | ||||
| if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists") | |||||
| if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") | |||||
| except OperationalError as e: | except OperationalError as e: | ||||
| logging.exception(e) | logging.exception(e) | ||||
| return get_error_data_result(message="Database operation failed") | return get_error_data_result(message="Database operation failed") | ||||
| try: | try: | ||||
| if not KnowledgebaseService.save(**req): | if not KnowledgebaseService.save(**req): | ||||
| return get_error_data_result(message="Database operation failed") | |||||
| return get_error_data_result(message="Create dataset error.(Database error)") | |||||
| except OperationalError as e: | except OperationalError as e: | ||||
| logging.exception(e) | logging.exception(e) | ||||
| return get_error_data_result(message="Database operation failed") | return get_error_data_result(message="Database operation failed") | ||||
| schema: | schema: | ||||
| type: object | type: object | ||||
| """ | """ | ||||
| errors = [] | errors = [] | ||||
| success_count = 0 | success_count = 0 | ||||
| req = request.json | req = request.json | ||||
| name: | name: | ||||
| type: string | type: string | ||||
| description: New name of the dataset. | description: New name of the dataset. | ||||
| avatar: | |||||
| type: string | |||||
| description: Updated base64 encoding of the avatar. | |||||
| description: | |||||
| type: string | |||||
| description: Updated description of the dataset. | |||||
| embedding_model: | |||||
| type: string | |||||
| description: Updated embedding model Name. | |||||
| permission: | permission: | ||||
| type: string | type: string | ||||
| enum: ['me', 'team'] | enum: ['me', 'team'] | ||||
| description: Updated permission. | |||||
| description: Updated dataset permission. | |||||
| chunk_method: | chunk_method: | ||||
| type: string | type: string | ||||
| enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", | |||||
| "presentation", "picture", "one", "email", "tag" | |||||
| enum: ["naive", "book", "email", "laws", "manual", "one", "paper", | |||||
| "picture", "presentation", "qa", "table", "tag" | |||||
| ] | ] | ||||
| description: Updated chunking method. | description: Updated chunking method. | ||||
| pagerank: | |||||
| type: integer | |||||
| description: Updated page rank. | |||||
| parser_config: | parser_config: | ||||
| type: object | type: object | ||||
| description: Updated parser configuration. | description: Updated parser configuration. | ||||
| schema: | schema: | ||||
| type: object | type: object | ||||
| """ | """ | ||||
| if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): | |||||
| return get_error_data_result(message="You don't own the dataset") | |||||
| req = request.json | |||||
| for k in req.keys(): | |||||
| if dataset_readonly_fields(k): | |||||
| return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.") | |||||
| e, t = TenantService.get_by_id(tenant_id) | |||||
| invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"} | |||||
| if any(key in req for key in invalid_keys): | |||||
| return get_error_data_result(message="The input parameters are invalid.") | |||||
| permission = req.get("permission") | |||||
| chunk_method = req.get("chunk_method") | |||||
| parser_config = req.get("parser_config") | |||||
| valid_parser_config(parser_config) | |||||
| valid_permission = ["me", "team"] | |||||
| valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"] | |||||
| check_validation = valid( | |||||
| permission, | |||||
| valid_permission, | |||||
| chunk_method, | |||||
| valid_chunk_method, | |||||
| ) | |||||
| if check_validation: | |||||
| return check_validation | |||||
| if "tenant_id" in req: | |||||
| if req["tenant_id"] != tenant_id: | |||||
| return get_error_data_result(message="Can't change `tenant_id`.") | |||||
| e, kb = KnowledgebaseService.get_by_id(dataset_id) | |||||
| if "parser_config" in req: | |||||
| temp_dict = kb.parser_config | |||||
| temp_dict.update(req["parser_config"]) | |||||
| req["parser_config"] = temp_dict | |||||
| if "chunk_count" in req: | |||||
| if req["chunk_count"] != kb.chunk_num: | |||||
| return get_error_data_result(message="Can't change `chunk_count`.") | |||||
| req.pop("chunk_count") | |||||
| if "document_count" in req: | |||||
| if req["document_count"] != kb.doc_num: | |||||
| return get_error_data_result(message="Can't change `document_count`.") | |||||
| req.pop("document_count") | |||||
| if req.get("chunk_method"): | |||||
| if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id: | |||||
| return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.") | |||||
| req["parser_id"] = req.pop("chunk_method") | |||||
| if req["parser_id"] != kb.parser_id: | |||||
| if not req.get("parser_config"): | |||||
| req["parser_config"] = get_parser_config(chunk_method, parser_config) | |||||
| if "embedding_model" in req: | |||||
| if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id: | |||||
| return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.") | |||||
| if not req.get("embedding_model"): | |||||
| return get_error_data_result("`embedding_model` can't be empty") | |||||
| valid_embedding_models = [ | |||||
| "BAAI/bge-large-zh-v1.5", | |||||
| "BAAI/bge-base-en-v1.5", | |||||
| "BAAI/bge-large-en-v1.5", | |||||
| "BAAI/bge-small-en-v1.5", | |||||
| "BAAI/bge-small-zh-v1.5", | |||||
| "jinaai/jina-embeddings-v2-base-en", | |||||
| "jinaai/jina-embeddings-v2-small-en", | |||||
| "nomic-ai/nomic-embed-text-v1.5", | |||||
| "sentence-transformers/all-MiniLM-L6-v2", | |||||
| "text-embedding-v2", | |||||
| "text-embedding-v3", | |||||
| "maidalun1020/bce-embedding-base_v1", | |||||
| ] | |||||
| embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding") | |||||
| if embd_model: | |||||
| if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query( | |||||
| tenant_id=tenant_id, | |||||
| model_type="embedding", | |||||
| llm_name=req.get("embedding_model"), | |||||
| ): | |||||
| return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") | |||||
| if not embd_model: | |||||
| embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model")) | |||||
| # Field name transformations during model dump: | |||||
| # | Original | Dump Output | | |||||
| # |----------------|-------------| | |||||
| # | embedding_model| embd_id | | |||||
| # | chunk_method | parser_id | | |||||
| extras = {"dataset_id": dataset_id} | |||||
| req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) | |||||
| if err is not None: | |||||
| return get_error_argument_result(err) | |||||
| if not req: | |||||
| return get_error_argument_result(message="No properties were modified") | |||||
| try: | |||||
| kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) | |||||
| if kb is None: | |||||
| return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") | |||||
| except OperationalError as e: | |||||
| logging.exception(e) | |||||
| return get_error_data_result(message="Database operation failed") | |||||
| if req.get("parser_config"): | |||||
| req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) | |||||
| if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id and req.get("parser_config") is None: | |||||
| req["parser_config"] = get_parser_config(chunk_method, None) | |||||
| if "name" in req and req["name"].lower() != kb.name.lower(): | |||||
| try: | |||||
| exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) | |||||
| if exists: | |||||
| return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") | |||||
| except OperationalError as e: | |||||
| logging.exception(e) | |||||
| return get_error_data_result(message="Database operation failed") | |||||
| if "embd_id" in req: | |||||
| if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: | |||||
| return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") | |||||
| ok, err = verify_embedding_availability(req["embd_id"], tenant_id) | |||||
| if not ok: | |||||
| return err | |||||
| try: | |||||
| if not KnowledgebaseService.update_by_id(kb.id, req): | |||||
| return get_error_data_result(message="Update dataset error.(Database error)") | |||||
| except OperationalError as e: | |||||
| logging.exception(e) | |||||
| return get_error_data_result(message="Database operation failed") | |||||
| if not embd_model: | |||||
| return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") | |||||
| req["embd_id"] = req.pop("embedding_model") | |||||
| if "name" in req: | |||||
| req["name"] = req["name"].strip() | |||||
| if len(req["name"]) >= 128: | |||||
| return get_error_data_result(message="Dataset name should not be longer than 128 characters.") | |||||
| if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: | |||||
| return get_error_data_result(message="Duplicated dataset name in updating dataset.") | |||||
| flds = list(req.keys()) | |||||
| for f in flds: | |||||
| if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]: | |||||
| del req[f] | |||||
| if not KnowledgebaseService.update_by_id(kb.id, req): | |||||
| return get_error_data_result(message="Update dataset error.(Database error)") | |||||
| return get_result(code=settings.RetCode.SUCCESS) | return get_result(code=settings.RetCode.SUCCESS) | ||||
| import random | import random | ||||
| import time | import time | ||||
| from base64 import b64encode | from base64 import b64encode | ||||
| from copy import deepcopy | |||||
| from functools import wraps | from functools import wraps | ||||
| from hmac import HMAC | from hmac import HMAC | ||||
| from io import BytesIO | from io import BytesIO | ||||
| return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] | return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34] | ||||
| def valid(permission, valid_permission, chunk_method, valid_chunk_method): | |||||
| if valid_parameter(permission, valid_permission): | |||||
| return valid_parameter(permission, valid_permission) | |||||
| if valid_parameter(chunk_method, valid_chunk_method): | |||||
| return valid_parameter(chunk_method, valid_chunk_method) | |||||
| def valid_parameter(parameter, valid_values): | |||||
| if parameter and parameter not in valid_values: | |||||
| return get_error_data_result(f"'{parameter}' is not in {valid_values}") | |||||
| def dataset_readonly_fields(field_name): | |||||
| return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"] | |||||
| def get_parser_config(chunk_method, parser_config): | def get_parser_config(chunk_method, parser_config): | ||||
| if parser_config: | if parser_config: | ||||
| return parser_config | return parser_config | ||||
| } | } | ||||
| def valid_parser_config(parser_config): | |||||
| if not parser_config: | |||||
| return | |||||
| scopes = set( | |||||
| [ | |||||
| "chunk_token_num", | |||||
| "delimiter", | |||||
| "raptor", | |||||
| "graphrag", | |||||
| "layout_recognize", | |||||
| "task_page_size", | |||||
| "pages", | |||||
| "html4excel", | |||||
| "auto_keywords", | |||||
| "auto_questions", | |||||
| "tag_kb_ids", | |||||
| "topn_tags", | |||||
| "filename_embd_weight", | |||||
| ] | |||||
| ) | |||||
| for k in parser_config.keys(): | |||||
| assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}" | |||||
| assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int" | |||||
| assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000" | |||||
| assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int" | |||||
| assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000" | |||||
| assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int" | |||||
| assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32" | |||||
| assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int" | |||||
| assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10" | |||||
| assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int" | |||||
| assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10" | |||||
| assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False" | |||||
| assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str" | |||||
| def check_duplicate_ids(ids, id_type="item"): | def check_duplicate_ids(ids, id_type="item"): | ||||
| """ | """ | ||||
| Check for duplicate IDs in a list and return unique IDs and error messages. | Check for duplicate IDs in a list and return unique IDs and error messages. | ||||
| def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: | def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: | ||||
| """Verifies availability of an embedding model for a specific tenant. | |||||
| """ | |||||
| Verifies availability of an embedding model for a specific tenant. | |||||
| Implements a four-stage validation process: | Implements a four-stage validation process: | ||||
| 1. Model identifier parsing and validation | 1. Model identifier parsing and validation | ||||
| return False, get_error_data_result(message="Database operation failed") | return False, get_error_data_result(message="Database operation failed") | ||||
| return True, None | return True, None | ||||
| def deep_merge(default: dict, custom: dict) -> dict: | |||||
| """ | |||||
| Recursively merges two dictionaries with priority given to `custom` values. | |||||
| Creates a deep copy of the `default` dictionary and iteratively merges nested | |||||
| dictionaries using a stack-based approach. Non-dict values in `custom` will | |||||
| completely override corresponding entries in `default`. | |||||
| Args: | |||||
| default (dict): Base dictionary containing default values. | |||||
| custom (dict): Dictionary containing overriding values. | |||||
| Returns: | |||||
| dict: New merged dictionary combining values from both inputs. | |||||
| Example: | |||||
| >>> from copy import deepcopy | |||||
| >>> default = {"a": 1, "nested": {"x": 10, "y": 20}} | |||||
| >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}} | |||||
| >>> deep_merge(default, custom) | |||||
| {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}} | |||||
| >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"}) | |||||
| {'config': 'manual'} | |||||
| Notes: | |||||
| 1. Merge priority is always given to `custom` values at all nesting levels | |||||
| 2. Non-dict values (e.g. list, str) in `custom` will replace entire values | |||||
| in `default`, even if the original value was a dictionary | |||||
| 3. Time complexity: O(N) where N is total key-value pairs in `custom` | |||||
| 4. Recommended for configuration merging and nested data updates | |||||
| """ | |||||
| merged = deepcopy(default) | |||||
| stack = [(merged, custom)] | |||||
| while stack: | |||||
| base_dict, override_dict = stack.pop() | |||||
| for key, val in override_dict.items(): | |||||
| if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict): | |||||
| stack.append((base_dict[key], val)) | |||||
| else: | |||||
| base_dict[key] = val | |||||
| return merged |
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import uuid | |||||
| from enum import auto | from enum import auto | ||||
| from typing import Annotated, Any | from typing import Annotated, Any | ||||
| from flask import Request | from flask import Request | ||||
| from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator | |||||
| from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator | |||||
| from strenum import StrEnum | from strenum import StrEnum | ||||
| from werkzeug.exceptions import BadRequest, UnsupportedMediaType | from werkzeug.exceptions import BadRequest, UnsupportedMediaType | ||||
| from api.constants import DATASET_NAME_LIMIT | from api.constants import DATASET_NAME_LIMIT | ||||
| def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]: | |||||
| """Validates and parses JSON requests through a multi-stage validation pipeline. | |||||
| def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: | |||||
| """ | |||||
| Validates and parses JSON requests through a multi-stage validation pipeline. | |||||
| Implements a robust four-stage validation process: | |||||
| Implements a four-stage validation process: | |||||
| 1. Content-Type verification (must be application/json) | 1. Content-Type verification (must be application/json) | ||||
| 2. JSON syntax validation | 2. JSON syntax validation | ||||
| 3. Payload structure type checking | 3. Payload structure type checking | ||||
| Args: | Args: | ||||
| request (Request): Flask request object containing HTTP payload | request (Request): Flask request object containing HTTP payload | ||||
| validator (type[BaseModel]): Pydantic model class for data validation | |||||
| extras (dict[str, Any] | None): Additional fields to merge into payload | |||||
| before validation. These fields will be removed from the final output | |||||
| exclude_unset (bool): Whether to exclude fields that have not been explicitly set | |||||
| Returns: | Returns: | ||||
| tuple[Dict[str, Any] | None, str | None]: | tuple[Dict[str, Any] | None, str | None]: | ||||
| - Diagnostic error message on failure | - Diagnostic error message on failure | ||||
| Raises: | Raises: | ||||
| UnsupportedMediaType: When Content-Type ≠ application/json | |||||
| UnsupportedMediaType: When Content-Type header is not application/json | |||||
| BadRequest: For structural JSON syntax errors | BadRequest: For structural JSON syntax errors | ||||
| ValidationError: When payload violates Pydantic schema rules | ValidationError: When payload violates Pydantic schema rules | ||||
| Examples: | Examples: | ||||
| Successful validation: | |||||
| ```python | |||||
| # Input: {"name": "Dataset1", "format": "csv"} | |||||
| # Returns: ({"name": "Dataset1", "format": "csv"}, None) | |||||
| ``` | |||||
| Invalid Content-Type: | |||||
| ```python | |||||
| # Returns: (None, "Unsupported content type: Expected application/json, got text/xml") | |||||
| ``` | |||||
| Malformed JSON: | |||||
| ```python | |||||
| # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding") | |||||
| ``` | |||||
| >>> validate_and_parse_json_request(valid_request, DatasetSchema) | |||||
| ({"name": "Dataset1", "format": "csv"}, None) | |||||
| >>> validate_and_parse_json_request(xml_request, DatasetSchema) | |||||
| (None, "Unsupported content type: Expected application/json, got text/xml") | |||||
| >>> validate_and_parse_json_request(bad_json_request, DatasetSchema) | |||||
| (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding") | |||||
| Notes: | |||||
| 1. Validation Priority: | |||||
| - Content-Type verification precedes JSON parsing | |||||
| - Structural validation occurs before schema validation | |||||
| 2. Extra fields added via `extras` parameter are automatically removed | |||||
| from the final output after validation | |||||
| """ | """ | ||||
| try: | try: | ||||
| payload = request.get_json() or {} | payload = request.get_json() or {} | ||||
| return None, f"Invalid request payload: expected object, got {type(payload).__name__}" | return None, f"Invalid request payload: expected object, got {type(payload).__name__}" | ||||
| try: | try: | ||||
| if extras is not None: | |||||
| payload.update(extras) | |||||
| validated_request = validator(**payload) | validated_request = validator(**payload) | ||||
| except ValidationError as e: | except ValidationError as e: | ||||
| return None, format_validation_error_message(e) | return None, format_validation_error_message(e) | ||||
| parsed_payload = validated_request.model_dump(by_alias=True) | |||||
| parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset) | |||||
| if extras is not None: | |||||
| for key in list(parsed_payload.keys()): | |||||
| if key in extras: | |||||
| del parsed_payload[key] | |||||
| return parsed_payload, None | return parsed_payload, None | ||||
| def format_validation_error_message(e: ValidationError) -> str: | def format_validation_error_message(e: ValidationError) -> str: | ||||
| """Formats validation errors into a standardized string format. | |||||
| """ | |||||
| Formats validation errors into a standardized string format. | |||||
| Processes pydantic ValidationError objects to create human-readable error messages | Processes pydantic ValidationError objects to create human-readable error messages | ||||
| containing field locations, error descriptions, and input values. | containing field locations, error descriptions, and input values. | ||||
| class Base(BaseModel): | class Base(BaseModel): | ||||
| class Config: | class Config: | ||||
| extra = "forbid" | extra = "forbid" | ||||
| json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"} | |||||
| class RaptorConfig(Base): | class RaptorConfig(Base): | ||||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] | name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] | ||||
| avatar: str | None = Field(default=None, max_length=65535) | avatar: str | None = Field(default=None, max_length=65535) | ||||
| description: str | None = Field(default=None, max_length=65535) | description: str | None = Field(default=None, max_length=65535) | ||||
| embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] | |||||
| embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] | |||||
| permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] | permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] | ||||
| chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] | chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] | ||||
| pagerank: int = Field(default=0, ge=0, le=100) | pagerank: int = Field(default=0, ge=0, le=100) | ||||
| parser_config: ParserConfig | None = Field(default=None) | |||||
| parser_config: ParserConfig = Field(default_factory=dict) | |||||
| @field_validator("avatar") | @field_validator("avatar") | ||||
| @classmethod | @classmethod | ||||
| def validate_avatar_base64(cls, v: str) -> str: | |||||
| """Validates Base64-encoded avatar string format and MIME type compliance. | |||||
| def validate_avatar_base64(cls, v: str | None) -> str | None: | |||||
| """ | |||||
| Validates Base64-encoded avatar string format and MIME type compliance. | |||||
| Implements a three-stage validation workflow: | Implements a three-stage validation workflow: | ||||
| 1. MIME prefix existence check | 1. MIME prefix existence check | ||||
| @field_validator("embedding_model", mode="after") | @field_validator("embedding_model", mode="after") | ||||
| @classmethod | @classmethod | ||||
| def validate_embedding_model(cls, v: str) -> str: | def validate_embedding_model(cls, v: str) -> str: | ||||
| """Validates embedding model identifier format compliance. | |||||
| """ | |||||
| Validates embedding model identifier format compliance. | |||||
| Validation pipeline: | Validation pipeline: | ||||
| 1. Structural format verification | 1. Structural format verification | ||||
| @field_validator("permission", mode="before") | @field_validator("permission", mode="before") | ||||
| @classmethod | @classmethod | ||||
| def permission_auto_lowercase(cls, v: str) -> str: | |||||
| """Normalize permission input to lowercase for consistent PermissionEnum matching. | |||||
| def permission_auto_lowercase(cls, v: Any) -> Any: | |||||
| """ | |||||
| Normalize permission input to lowercase for consistent PermissionEnum matching. | |||||
| Args: | Args: | ||||
| v (str): Raw input value for the permission field | |||||
| v (Any): Raw input value for the permission field | |||||
| Returns: | Returns: | ||||
| Lowercase string if input is string type, otherwise returns original value | Lowercase string if input is string type, otherwise returns original value | ||||
| @field_validator("parser_config", mode="after") | @field_validator("parser_config", mode="after") | ||||
| @classmethod | @classmethod | ||||
| def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None: | |||||
| """Validates serialized JSON length constraints for parser configuration. | |||||
| def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig: | |||||
| """ | |||||
| Validates serialized JSON length constraints for parser configuration. | |||||
| Implements a three-stage validation workflow: | |||||
| 1. Null check - bypass validation for empty configurations | |||||
| 2. Model serialization - convert Pydantic model to JSON string | |||||
| 3. Size verification - enforce maximum allowed payload size | |||||
| Implements a two-stage validation workflow: | |||||
| 1. Model serialization - convert Pydantic model to JSON string | |||||
| 2. Size verification - enforce maximum allowed payload size | |||||
| Args: | Args: | ||||
| v (ParserConfig | None): Raw parser configuration object | v (ParserConfig | None): Raw parser configuration object | ||||
| Raises: | Raises: | ||||
| ValueError: When serialized JSON exceeds 65,535 characters | ValueError: When serialized JSON exceeds 65,535 characters | ||||
| """ | """ | ||||
| if v is None: | |||||
| return v | |||||
| if (json_str := v.model_dump_json()) and len(json_str) > 65535: | if (json_str := v.model_dump_json()) and len(json_str) > 65535: | ||||
| raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}") | raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}") | ||||
| return v | return v | ||||
| class UpdateDatasetReq(CreateDatasetReq): | |||||
| dataset_id: UUID1 = Field(...) | |||||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] | |||||
| @field_serializer("dataset_id") | |||||
| def serialize_uuid_to_hex(self, v: uuid.UUID) -> str: | |||||
| return v.hex |
| - `"team"`: All team members can manage the dataset. | - `"team"`: All team members can manage the dataset. | ||||
| - `"pagerank"`: (*Body parameter*), `int` | - `"pagerank"`: (*Body parameter*), `int` | ||||
| Set page rank: refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) | |||||
| refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) | |||||
| - Default: `0` | - Default: `0` | ||||
| - Minimum: `0` | - Minimum: `0` | ||||
| - Maximum: `100` | - Maximum: `100` | ||||
| - `'Authorization: Bearer <YOUR_API_KEY>'` | - `'Authorization: Bearer <YOUR_API_KEY>'` | ||||
| - Body: | - Body: | ||||
| - `"name"`: `string` | - `"name"`: `string` | ||||
| - `"avatar"`: `string` | |||||
| - `"description"`: `string` | |||||
| - `"embedding_model"`: `string` | - `"embedding_model"`: `string` | ||||
| - `"chunk_method"`: `enum<string>` | |||||
| - `"permission"`: `string` | |||||
| - `"chunk_method"`: `string` | |||||
| - `"pagerank"`: `int` | |||||
| - `"parser_config"`: `object` | |||||
| ##### Request example | ##### Request example | ||||
| The ID of the dataset to update. | The ID of the dataset to update. | ||||
| - `"name"`: (*Body parameter*), `string` | - `"name"`: (*Body parameter*), `string` | ||||
| The revised name of the dataset. | The revised name of the dataset. | ||||
| - Basic Multilingual Plane (BMP) only | |||||
| - Maximum 128 characters | |||||
| - Case-insensitive | |||||
| - `"avatar"`: (*Body parameter*), `string` | |||||
| The updated base64 encoding of the avatar. | |||||
| - Maximum 65535 characters | |||||
| - `"embedding_model"`: (*Body parameter*), `string` | - `"embedding_model"`: (*Body parameter*), `string` | ||||
| The updated embedding model name. | The updated embedding model name. | ||||
| - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. | - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. | ||||
| - Maximum 255 characters | |||||
| - Must follow `model_name@model_factory` format | |||||
| - `"permission"`: (*Body parameter*), `string` | |||||
| The updated dataset permission. Available options: | |||||
| - `"me"`: (Default) Only you can manage the dataset. | |||||
| - `"team"`: All team members can manage the dataset. | |||||
| - `"pagerank"`: (*Body parameter*), `int` | |||||
| refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) | |||||
| - Default: `0` | |||||
| - Minimum: `0` | |||||
| - Maximum: `100` | |||||
| - `"chunk_method"`: (*Body parameter*), `enum<string>` | - `"chunk_method"`: (*Body parameter*), `enum<string>` | ||||
| The chunking method for the dataset. Available options: | The chunking method for the dataset. Available options: | ||||
| - `"naive"`: General | |||||
| - `"manual`: Manual | |||||
| - `"qa"`: Q&A | |||||
| - `"table"`: Table | |||||
| - `"paper"`: Paper | |||||
| - `"naive"`: General (default) | |||||
| - `"book"`: Book | - `"book"`: Book | ||||
| - `"email"`: Email | |||||
| - `"laws"`: Laws | - `"laws"`: Laws | ||||
| - `"presentation"`: Presentation | |||||
| - `"manual"`: Manual | |||||
| - `"one"`: One | |||||
| - `"paper"`: Paper | |||||
| - `"picture"`: Picture | - `"picture"`: Picture | ||||
| - `"one"`:One | |||||
| - `"email"`: Email | |||||
| - `"presentation"`: Presentation | |||||
| - `"qa"`: Q&A | |||||
| - `"table"`: Table | |||||
| - `"tag"`: Tag | |||||
| - `"parser_config"`: (*Body parameter*), `object` | |||||
| The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`: | |||||
| - If `"chunk_method"` is `"naive"`, the `"parser_config"` object contains the following attributes: | |||||
| - `"auto_keywords"`: `int` | |||||
| - Defaults to `0` | |||||
| - Minimum: `0` | |||||
| - Maximum: `32` | |||||
| - `"auto_questions"`: `int` | |||||
| - Defaults to `0` | |||||
| - Minimum: `0` | |||||
| - Maximum: `10` | |||||
| - `"chunk_token_num"`: `int` | |||||
| - Defaults to `128` | |||||
| - Minimum: `1` | |||||
| - Maximum: `2048` | |||||
| - `"delimiter"`: `string` | |||||
| - Defaults to `"\n"`. | |||||
| - `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format. | |||||
| - Defaults to `false` | |||||
| - `"layout_recognize"`: `string` | |||||
| - Defaults to `DeepDOC` | |||||
| - `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets) | |||||
| - Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunk Method | |||||
| - `"task_page_size"`: `int` For PDF only. | |||||
| - Defaults to `12` | |||||
| - Minimum: `1` | |||||
| - `"raptor"`: `object` RAPTOR-specific settings. | |||||
| - Defaults to: `{"use_raptor": false}` | |||||
| - `"graphrag"`: `object` GRAPHRAG-specific settings. | |||||
| - Defaults to: `{"use_graphrag": false}` | |||||
| - If `"chunk_method"` is `"qa"`, `"manuel"`, `"paper"`, `"book"`, `"laws"`, or `"presentation"`, the `"parser_config"` object contains the following attribute: | |||||
| - `"raptor"`: `object` RAPTOR-specific settings. | |||||
| - Defaults to: `{"use_raptor": false}`. | |||||
| - If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object. | |||||
| #### Response | #### Response | ||||
| A dictionary representing the attributes to update, with the following keys: | A dictionary representing the attributes to update, with the following keys: | ||||
| - `"name"`: `str` The revised name of the dataset. | - `"name"`: `str` The revised name of the dataset. | ||||
| - `"embedding_model"`: `str` The updated embedding model name. | |||||
| - Basic Multilingual Plane (BMP) only | |||||
| - Maximum 128 characters | |||||
| - Case-insensitive | |||||
| - `"avatar"`: (*Body parameter*), `string` | |||||
| The updated base64 encoding of the avatar. | |||||
| - Maximum 65535 characters | |||||
| - `"embedding_model"`: (*Body parameter*), `string` | |||||
| The updated embedding model name. | |||||
| - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. | - Ensure that `"chunk_count"` is `0` before updating `"embedding_model"`. | ||||
| - `"chunk_method"`: `str` The chunking method for the dataset. Available options: | |||||
| - `"naive"`: General | |||||
| - `"manual`: Manual | |||||
| - `"qa"`: Q&A | |||||
| - `"table"`: Table | |||||
| - `"paper"`: Paper | |||||
| - Maximum 255 characters | |||||
| - Must follow `model_name@model_factory` format | |||||
| - `"permission"`: (*Body parameter*), `string` | |||||
| The updated dataset permission. Available options: | |||||
| - `"me"`: (Default) Only you can manage the dataset. | |||||
| - `"team"`: All team members can manage the dataset. | |||||
| - `"pagerank"`: (*Body parameter*), `int` | |||||
| refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) | |||||
| - Default: `0` | |||||
| - Minimum: `0` | |||||
| - Maximum: `100` | |||||
| - `"chunk_method"`: (*Body parameter*), `enum<string>` | |||||
| The chunking method for the dataset. Available options: | |||||
| - `"naive"`: General (default) | |||||
| - `"book"`: Book | - `"book"`: Book | ||||
| - `"email"`: Email | |||||
| - `"laws"`: Laws | - `"laws"`: Laws | ||||
| - `"presentation"`: Presentation | |||||
| - `"picture"`: Picture | |||||
| - `"manual"`: Manual | |||||
| - `"one"`: One | - `"one"`: One | ||||
| - `"email"`: Email | |||||
| - `"paper"`: Paper | |||||
| - `"picture"`: Picture | |||||
| - `"presentation"`: Presentation | |||||
| - `"qa"`: Q&A | |||||
| - `"table"`: Table | |||||
| - `"tag"`: Tag | |||||
| #### Returns | #### Returns | ||||
| pagerank: int = 0, | pagerank: int = 0, | ||||
| parser_config: DataSet.ParserConfig = None, | parser_config: DataSet.ParserConfig = None, | ||||
| ) -> DataSet: | ) -> DataSet: | ||||
| if parser_config: | |||||
| parser_config = parser_config.to_json() | |||||
| res = self.post( | |||||
| "/datasets", | |||||
| { | |||||
| "name": name, | |||||
| "avatar": avatar, | |||||
| "description": description, | |||||
| "embedding_model": embedding_model, | |||||
| "permission": permission, | |||||
| "chunk_method": chunk_method, | |||||
| "pagerank": pagerank, | |||||
| "parser_config": parser_config, | |||||
| }, | |||||
| ) | |||||
| payload = { | |||||
| "name": name, | |||||
| "avatar": avatar, | |||||
| "description": description, | |||||
| "embedding_model": embedding_model, | |||||
| "permission": permission, | |||||
| "chunk_method": chunk_method, | |||||
| "pagerank": pagerank, | |||||
| } | |||||
| if parser_config is not None: | |||||
| payload["parser_config"] = parser_config.to_json() | |||||
| res = self.post("/datasets", payload) | |||||
| res = res.json() | res = res.json() | ||||
| if res.get("code") == 0: | if res.get("code") == 0: | ||||
| return DataSet(self, res["data"]) | return DataSet(self, res["data"]) |
| # | |||||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import hypothesis.strategies as st | |||||
| @st.composite | |||||
| def valid_names(draw): | |||||
| base_chars = "abcdefghijklmnopqrstuvwxyz_" | |||||
| first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"])) | |||||
| remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=128 - 2)) | |||||
| name = (first_char + remaining)[:128] | |||||
| return name.encode("utf-8").decode("utf-8") |
| # DATASET MANAGEMENT | # DATASET MANAGEMENT | ||||
| def create_dataset(auth, payload=None, headers=HEADERS, data=None): | |||||
| def create_dataset(auth, payload=None, *, headers=HEADERS, data=None): | |||||
| res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) | res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) | ||||
| return res.json() | return res.json() | ||||
| def list_datasets(auth, params=None, headers=HEADERS): | |||||
| def list_datasets(auth, params=None, *, headers=HEADERS): | |||||
| res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) | res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) | ||||
| return res.json() | return res.json() | ||||
| def update_dataset(auth, dataset_id, payload=None, headers=HEADERS): | |||||
| res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload) | |||||
| def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): | |||||
| res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data) | |||||
| return res.json() | return res.json() | ||||
| def delete_datasets(auth, payload=None, headers=HEADERS): | |||||
| res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload) | |||||
| def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): | |||||
| res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) | |||||
| return res.json() | return res.json() | ||||
| request.addfinalizer(cleanup) | request.addfinalizer(cleanup) | ||||
| return batch_create_datasets(get_http_api_auth, 3) | return batch_create_datasets(get_http_api_auth, 3) | ||||
| @pytest.fixture(scope="function") | |||||
| def add_dataset_func(get_http_api_auth, request): | |||||
| def cleanup(): | |||||
| delete_datasets(get_http_api_auth) | |||||
| request.addfinalizer(cleanup) | |||||
| return batch_create_datasets(get_http_api_auth, 1)[0] |
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| from concurrent.futures import ThreadPoolExecutor | |||||
| import hypothesis.strategies as st | |||||
| import pytest | import pytest | ||||
| from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset | from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset | ||||
| from hypothesis import example, given, settings | from hypothesis import example, given, settings | ||||
| from libs.auth import RAGFlowHttpApiAuth | from libs.auth import RAGFlowHttpApiAuth | ||||
| from libs.utils import encode_avatar | from libs.utils import encode_avatar | ||||
| from libs.utils.file_utils import create_image_file | from libs.utils.file_utils import create_image_file | ||||
| from libs.utils.hypothesis_utils import valid_names | |||||
| @st.composite | |||||
| def valid_names(draw): | |||||
| base_chars = "abcdefghijklmnopqrstuvwxyz_" | |||||
| first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"])) | |||||
| remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=DATASET_NAME_LIMIT - 2)) | |||||
| name = (first_char + remaining)[:128] | |||||
| return name.encode("utf-8").decode("utf-8") | |||||
| @pytest.mark.p1 | |||||
| @pytest.mark.usefixtures("clear_datasets") | @pytest.mark.usefixtures("clear_datasets") | ||||
| class TestAuthorization: | class TestAuthorization: | ||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "auth, expected_code, expected_message", | "auth, expected_code, expected_message", | ||||
| [ | [ | ||||
| ], | ], | ||||
| ids=["empty_auth", "invalid_api_token"], | ids=["empty_auth", "invalid_api_token"], | ||||
| ) | ) | ||||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||||
| def test_auth_invalid(self, auth, expected_code, expected_message): | |||||
| res = create_dataset(auth, {"name": "auth_test"}) | res = create_dataset(auth, {"name": "auth_test"}) | ||||
| assert res["code"] == expected_code | |||||
| assert res["message"] == expected_message | |||||
| assert res["code"] == expected_code, res | |||||
| assert res["message"] == expected_message, res | |||||
| class TestRquest: | |||||
| @pytest.mark.p3 | |||||
| def test_content_type_bad(self, get_http_api_auth): | |||||
| BAD_CONTENT_TYPE = "text/xml" | |||||
| res = create_dataset(get_http_api_auth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | |||||
| "payload, expected_message", | |||||
| [ | |||||
| ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), | |||||
| ('"a"', "Invalid request payload: expected object, got str"), | |||||
| ], | |||||
| ids=["malformed_json_syntax", "invalid_request_payload_type"], | |||||
| ) | |||||
| def test_payload_bad(self, get_http_api_auth, payload, expected_message): | |||||
| res = create_dataset(get_http_api_auth, data=payload) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == expected_message, res | |||||
| @pytest.mark.usefixtures("clear_datasets") | @pytest.mark.usefixtures("clear_datasets") | ||||
| class TestDatasetCreation: | |||||
| class TestCapability: | |||||
| @pytest.mark.p3 | |||||
| def test_create_dataset_1k(self, get_http_api_auth): | |||||
| for i in range(1_000): | |||||
| payload = {"name": f"dataset_{i}"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, f"Failed to create dataset {i}" | |||||
| @pytest.mark.p3 | |||||
| def test_create_dataset_concurrent(self, get_http_api_auth): | |||||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||||
| futures = [executor.submit(create_dataset, get_http_api_auth, {"name": f"dataset_{i}"}) for i in range(100)] | |||||
| responses = [f.result() for f in futures] | |||||
| assert all(r["code"] == 0 for r in responses), responses | |||||
| @pytest.mark.usefixtures("clear_datasets") | |||||
| class TestDatasetCreate: | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @given(name=valid_names()) | @given(name=valid_names()) | ||||
| @example("a" * 128) | @example("a" * 128) | ||||
| @settings(max_examples=20) | @settings(max_examples=20) | ||||
| def test_valid_name(self, get_http_api_auth, name): | |||||
| def test_name(self, get_http_api_auth, name): | |||||
| res = create_dataset(get_http_api_auth, {"name": name}) | res = create_dataset(get_http_api_auth, {"name": name}) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| assert res["data"]["name"] == name, res | assert res["data"]["name"] == name, res | ||||
| @pytest.mark.p1 | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, expected_message", | "name, expected_message", | ||||
| [ | [ | ||||
| (" ", "String should have at least 1 character"), | (" ", "String should have at least 1 character"), | ||||
| ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), | ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), | ||||
| (0, "Input should be a valid string"), | (0, "Input should be a valid string"), | ||||
| (None, "Input should be a valid string"), | |||||
| ], | ], | ||||
| ids=["empty_name", "space_name", "too_long_name", "invalid_name"], | |||||
| ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], | |||||
| ) | ) | ||||
| def test_invalid_name(self, get_http_api_auth, name, expected_message): | |||||
| res = create_dataset(get_http_api_auth, {"name": name}) | |||||
| def test_name_invalid(self, get_http_api_auth, name, expected_message): | |||||
| payload = {"name": name} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert expected_message in res["message"], res | assert expected_message in res["message"], res | ||||
| @pytest.mark.p2 | |||||
| def test_duplicated_name(self, get_http_api_auth): | |||||
| @pytest.mark.p3 | |||||
| def test_name_duplicated(self, get_http_api_auth): | |||||
| name = "duplicated_name" | name = "duplicated_name" | ||||
| payload = {"name": name} | payload = {"name": name} | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 101, res | |||||
| assert res["code"] == 102, res | |||||
| assert res["message"] == f"Dataset name '{name}' already exists", res | assert res["message"] == f"Dataset name '{name}' already exists", res | ||||
| @pytest.mark.p2 | |||||
| def test_case_insensitive(self, get_http_api_auth): | |||||
| @pytest.mark.p3 | |||||
| def test_name_case_insensitive(self, get_http_api_auth): | |||||
| name = "CaseInsensitive" | name = "CaseInsensitive" | ||||
| res = create_dataset(get_http_api_auth, {"name": name.upper()}) | |||||
| payload = {"name": name.upper()} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| res = create_dataset(get_http_api_auth, {"name": name.lower()}) | |||||
| assert res["code"] == 101, res | |||||
| payload = {"name": name.lower()} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 102, res | |||||
| assert res["message"] == f"Dataset name '{name.lower()}' already exists", res | assert res["message"] == f"Dataset name '{name.lower()}' already exists", res | ||||
| @pytest.mark.p3 | |||||
| def test_bad_content_type(self, get_http_api_auth): | |||||
| BAD_CONTENT_TYPE = "text/xml" | |||||
| res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE}) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | |||||
| "payload, expected_message", | |||||
| [ | |||||
| ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), | |||||
| ('"a"', "Invalid request payload: expected objec"), | |||||
| ], | |||||
| ids=["malformed_json_syntax", "invalid_request_payload_type"], | |||||
| ) | |||||
| def test_bad_payload(self, get_http_api_auth, payload, expected_message): | |||||
| res = create_dataset(get_http_api_auth, data=payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| def test_avatar(self, get_http_api_auth, tmp_path): | def test_avatar(self, get_http_api_auth, tmp_path): | ||||
| fn = create_image_file(tmp_path / "ragflow_test.png") | fn = create_image_file(tmp_path / "ragflow_test.png") | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| @pytest.mark.p3 | |||||
| def test_avatar_none(self, get_http_api_auth, tmp_path): | |||||
| payload = {"name": "test_avatar_none", "avatar": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["avatar"] is None, res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| def test_avatar_exceeds_limit_length(self, get_http_api_auth): | def test_avatar_exceeds_limit_length(self, get_http_api_auth): | ||||
| res = create_dataset(get_http_api_auth, {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}) | |||||
| payload = {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert "String should have at most 65535 characters" in res["message"], res | assert "String should have at most 65535 characters" in res["message"], res | ||||
| ], | ], | ||||
| ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], | ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], | ||||
| ) | ) | ||||
| def test_invalid_avatar_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message): | |||||
| def test_avatar_invalid_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message): | |||||
| fn = create_image_file(tmp_path / "ragflow_test.png") | fn = create_image_file(tmp_path / "ragflow_test.png") | ||||
| payload = { | payload = { | ||||
| "name": name, | "name": name, | ||||
| assert expected_message in res["message"], res | assert expected_message in res["message"], res | ||||
| @pytest.mark.p3 | @pytest.mark.p3 | ||||
| def test_description_none(self, get_http_api_auth): | |||||
| payload = {"name": "test_description_none", "description": None} | |||||
| def test_avatar_unset(self, get_http_api_auth): | |||||
| payload = {"name": "test_avatar_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| assert res["data"]["description"] is None, res | |||||
| assert res["data"]["avatar"] is None, res | |||||
| @pytest.mark.p3 | |||||
| def test_avatar_none(self, get_http_api_auth): | |||||
| payload = {"name": "test_avatar_none", "avatar": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["avatar"] is None, res | |||||
| @pytest.mark.p2 | |||||
| def test_description(self, get_http_api_auth): | |||||
| payload = {"name": "test_description", "description": "description"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["description"] == "description", res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| def test_description_exceeds_limit_length(self, get_http_api_auth): | def test_description_exceeds_limit_length(self, get_http_api_auth): | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert "String should have at most 65535 characters" in res["message"], res | assert "String should have at most 65535 characters" in res["message"], res | ||||
| @pytest.mark.p3 | |||||
| def test_description_unset(self, get_http_api_auth): | |||||
| payload = {"name": "test_description_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["description"] is None, res | |||||
| @pytest.mark.p3 | |||||
| def test_description_none(self, get_http_api_auth): | |||||
| payload = {"name": "test_description_none", "description": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["description"] is None, res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, embedding_model", | "name, embedding_model", | ||||
| ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), | ("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"), | ||||
| ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), | ("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"), | ||||
| ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), | ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), | ||||
| ("embedding_model_default", None), | |||||
| ], | ], | ||||
| ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"], | |||||
| ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], | |||||
| ) | ) | ||||
| def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model): | |||||
| if embedding_model is None: | |||||
| payload = {"name": name} | |||||
| else: | |||||
| payload = {"name": name, "embedding_model": embedding_model} | |||||
| def test_embedding_model(self, get_http_api_auth, name, embedding_model): | |||||
| payload = {"name": name, "embedding_model": embedding_model} | |||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| if embedding_model is None: | |||||
| assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res | |||||
| else: | |||||
| assert res["data"]["embedding_model"] == embedding_model, res | |||||
| assert res["data"]["embedding_model"] == embedding_model, res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| ], | ], | ||||
| ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], | ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], | ||||
| ) | ) | ||||
| def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model): | |||||
| def test_embedding_model_invalid(self, get_http_api_auth, name, embedding_model): | |||||
| payload = {"name": name, "embedding_model": embedding_model} | payload = {"name": name, "embedding_model": embedding_model} | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| else: | else: | ||||
| assert "Both model_name and provider must be non-empty strings" in res["message"], res | assert "Both model_name and provider must be non-empty strings" in res["message"], res | ||||
| @pytest.mark.p2 | |||||
| def test_embedding_model_unset(self, get_http_api_auth): | |||||
| payload = {"name": "embedding_model_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res | |||||
| @pytest.mark.p2 | |||||
| def test_embedding_model_none(self, get_http_api_auth): | |||||
| payload = {"name": "test_embedding_model_none", "embedding_model": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid string" in res["message"], res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, permission", | "name, permission", | ||||
| ("team", "team"), | ("team", "team"), | ||||
| ("me_upercase", "ME"), | ("me_upercase", "ME"), | ||||
| ("team_upercase", "TEAM"), | ("team_upercase", "TEAM"), | ||||
| ("permission_default", None), | |||||
| ], | ], | ||||
| ids=["me", "team", "me_upercase", "team_upercase", "permission_default"], | |||||
| ids=["me", "team", "me_upercase", "team_upercase"], | |||||
| ) | ) | ||||
| def test_valid_permission(self, get_http_api_auth, name, permission): | |||||
| if permission is None: | |||||
| payload = {"name": name} | |||||
| else: | |||||
| payload = {"name": name, "permission": permission} | |||||
| def test_permission(self, get_http_api_auth, name, permission): | |||||
| payload = {"name": name, "permission": permission} | |||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| if permission is None: | |||||
| assert res["data"]["permission"] == "me", res | |||||
| else: | |||||
| assert res["data"]["permission"] == permission.lower(), res | |||||
| assert res["data"]["permission"] == permission.lower(), res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| ("unknown", "unknown"), | ("unknown", "unknown"), | ||||
| ("type_error", list()), | ("type_error", list()), | ||||
| ], | ], | ||||
| ids=["empty", "unknown", "type_error"], | |||||
| ) | ) | ||||
| def test_invalid_permission(self, get_http_api_auth, name, permission): | |||||
| def test_permission_invalid(self, get_http_api_auth, name, permission): | |||||
| payload = {"name": name, "permission": permission} | payload = {"name": name, "permission": permission} | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 101 | assert res["code"] == 101 | ||||
| assert "Input should be 'me' or 'team'" in res["message"] | assert "Input should be 'me' or 'team'" in res["message"] | ||||
| @pytest.mark.p2 | |||||
| def test_permission_unset(self, get_http_api_auth): | |||||
| payload = {"name": "test_permission_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["permission"] == "me", res | |||||
| @pytest.mark.p3 | |||||
| def test_permission_none(self, get_http_api_auth): | |||||
| payload = {"name": "test_permission_none", "permission": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be 'me' or 'team'" in res["message"], res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, chunk_method", | "name, chunk_method", | ||||
| ("qa", "qa"), | ("qa", "qa"), | ||||
| ("table", "table"), | ("table", "table"), | ||||
| ("tag", "tag"), | ("tag", "tag"), | ||||
| ("chunk_method_default", None), | |||||
| ], | ], | ||||
| ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], | |||||
| ) | ) | ||||
| def test_valid_chunk_method(self, get_http_api_auth, name, chunk_method): | |||||
| if chunk_method is None: | |||||
| payload = {"name": name} | |||||
| else: | |||||
| payload = {"name": name, "chunk_method": chunk_method} | |||||
| def test_chunk_method(self, get_http_api_auth, name, chunk_method): | |||||
| payload = {"name": name, "chunk_method": chunk_method} | |||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| if chunk_method is None: | |||||
| assert res["data"]["chunk_method"] == "naive", res | |||||
| else: | |||||
| assert res["data"]["chunk_method"] == chunk_method, res | |||||
| assert res["data"]["chunk_method"] == chunk_method, res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| ("unknown", "unknown"), | ("unknown", "unknown"), | ||||
| ("type_error", list()), | ("type_error", list()), | ||||
| ], | ], | ||||
| ids=["empty", "unknown", "type_error"], | |||||
| ) | ) | ||||
| def test_invalid_chunk_method(self, get_http_api_auth, name, chunk_method): | |||||
| def test_chunk_method_invalid(self, get_http_api_auth, name, chunk_method): | |||||
| payload = {"name": name, "chunk_method": chunk_method} | payload = {"name": name, "chunk_method": chunk_method} | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | ||||
| @pytest.mark.p2 | |||||
| def test_chunk_method_unset(self, get_http_api_auth): | |||||
| payload = {"name": "test_chunk_method_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["chunk_method"] == "naive", res | |||||
| @pytest.mark.p3 | |||||
| def test_chunk_method_none(self, get_http_api_auth): | |||||
| payload = {"name": "chunk_method_none", "chunk_method": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "name, pagerank", | |||||
| [ | |||||
| ("pagerank_min", 0), | |||||
| ("pagerank_mid", 50), | |||||
| ("pagerank_max", 100), | |||||
| ], | |||||
| ids=["min", "mid", "max"], | |||||
| ) | |||||
| def test_pagerank(self, get_http_api_auth, name, pagerank): | |||||
| payload = {"name": name, "pagerank": pagerank} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["pagerank"] == pagerank, res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | |||||
| "name, pagerank, expected_message", | |||||
| [ | |||||
| ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), | |||||
| ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), | |||||
| ], | |||||
| ids=["min_limit", "max_limit"], | |||||
| ) | |||||
| def test_pagerank_invalid(self, get_http_api_auth, name, pagerank, expected_message): | |||||
| payload = {"name": name, "pagerank": pagerank} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_pagerank_unset(self, get_http_api_auth): | |||||
| payload = {"name": "pagerank_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["pagerank"] == 0, res | |||||
| @pytest.mark.p3 | |||||
| def test_pagerank_none(self, get_http_api_auth): | |||||
| payload = {"name": "pagerank_unset", "pagerank": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid integer" in res["message"], res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, parser_config", | "name, parser_config", | ||||
| [ | [ | ||||
| ("default_none", None), | |||||
| ("default_empty", {}), | |||||
| ("auto_keywords_min", {"auto_keywords": 0}), | ("auto_keywords_min", {"auto_keywords": 0}), | ||||
| ("auto_keywords_mid", {"auto_keywords": 16}), | ("auto_keywords_mid", {"auto_keywords": 16}), | ||||
| ("auto_keywords_max", {"auto_keywords": 32}), | ("auto_keywords_max", {"auto_keywords": 32}), | ||||
| ("task_page_size_min", {"task_page_size": 1}), | ("task_page_size_min", {"task_page_size": 1}), | ||||
| ("task_page_size_None", {"task_page_size": None}), | ("task_page_size_None", {"task_page_size": None}), | ||||
| ("pages", {"pages": [[1, 100]]}), | ("pages", {"pages": [[1, 100]]}), | ||||
| ("pages_none", None), | |||||
| ("pages_none", {"pages": None}), | |||||
| ("graphrag_true", {"graphrag": {"use_graphrag": True}}), | ("graphrag_true", {"graphrag": {"use_graphrag": True}}), | ||||
| ("graphrag_false", {"graphrag": {"use_graphrag": False}}), | ("graphrag_false", {"graphrag": {"use_graphrag": False}}), | ||||
| ("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}), | ("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}), | ||||
| ("raptor_random_seed_min", {"raptor": {"random_seed": 0}}), | ("raptor_random_seed_min", {"raptor": {"random_seed": 0}}), | ||||
| ], | ], | ||||
| ids=[ | ids=[ | ||||
| "default_none", | |||||
| "default_empty", | |||||
| "auto_keywords_min", | "auto_keywords_min", | ||||
| "auto_keywords_mid", | "auto_keywords_mid", | ||||
| "auto_keywords_max", | "auto_keywords_max", | ||||
| "raptor_random_seed_min", | "raptor_random_seed_min", | ||||
| ], | ], | ||||
| ) | ) | ||||
| def test_valid_parser_config(self, get_http_api_auth, name, parser_config): | |||||
| if parser_config is None: | |||||
| payload = {"name": name} | |||||
| else: | |||||
| payload = {"name": name, "parser_config": parser_config} | |||||
| def test_parser_config(self, get_http_api_auth, name, parser_config): | |||||
| payload = {"name": name, "parser_config": parser_config} | |||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 0, res | assert res["code"] == 0, res | ||||
| if parser_config is None: | |||||
| assert res["data"]["parser_config"] == { | |||||
| "chunk_token_num": 128, | |||||
| "delimiter": r"\n", | |||||
| "html4excel": False, | |||||
| "layout_recognize": "DeepDOC", | |||||
| "raptor": {"use_raptor": False}, | |||||
| } | |||||
| elif parser_config == {}: | |||||
| assert res["data"]["parser_config"] == { | |||||
| "auto_keywords": 0, | |||||
| "auto_questions": 0, | |||||
| "chunk_token_num": 128, | |||||
| "delimiter": r"\n", | |||||
| "filename_embd_weight": None, | |||||
| "graphrag": None, | |||||
| "html4excel": False, | |||||
| "layout_recognize": "DeepDOC", | |||||
| "pages": None, | |||||
| "raptor": None, | |||||
| "tag_kb_ids": [], | |||||
| "task_page_size": None, | |||||
| "topn_tags": 1, | |||||
| } | |||||
| else: | |||||
| for k, v in parser_config.items(): | |||||
| if isinstance(v, dict): | |||||
| for kk, vv in v.items(): | |||||
| assert res["data"]["parser_config"][k][kk] == vv | |||||
| else: | |||||
| assert res["data"]["parser_config"][k] == v | |||||
| for k, v in parser_config.items(): | |||||
| if isinstance(v, dict): | |||||
| for kk, vv in v.items(): | |||||
| assert res["data"]["parser_config"][k][kk] == vv, res | |||||
| else: | |||||
| assert res["data"]["parser_config"][k] == v, res | |||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "parser_config_type_invalid", | "parser_config_type_invalid", | ||||
| ], | ], | ||||
| ) | ) | ||||
| def test_invalid_parser_config(self, get_http_api_auth, name, parser_config, expected_message): | |||||
| def test_parser_config_invalid(self, get_http_api_auth, name, parser_config, expected_message): | |||||
| payload = {"name": name, "parser_config": parser_config} | payload = {"name": name, "parser_config": parser_config} | ||||
| res = create_dataset(get_http_api_auth, payload) | res = create_dataset(get_http_api_auth, payload) | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert expected_message in res["message"], res | assert expected_message in res["message"], res | ||||
| @pytest.mark.p2 | |||||
| def test_parser_config_empty(self, get_http_api_auth): | |||||
| payload = {"name": "default_empty", "parser_config": {}} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["parser_config"] == { | |||||
| "auto_keywords": 0, | |||||
| "auto_questions": 0, | |||||
| "chunk_token_num": 128, | |||||
| "delimiter": r"\n", | |||||
| "filename_embd_weight": None, | |||||
| "graphrag": None, | |||||
| "html4excel": False, | |||||
| "layout_recognize": "DeepDOC", | |||||
| "pages": None, | |||||
| "raptor": None, | |||||
| "tag_kb_ids": [], | |||||
| "task_page_size": None, | |||||
| "topn_tags": 1, | |||||
| } | |||||
| @pytest.mark.p2 | |||||
| def test_parser_config_unset(self, get_http_api_auth): | |||||
| payload = {"name": "default_unset"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["parser_config"] == { | |||||
| "chunk_token_num": 128, | |||||
| "delimiter": r"\n", | |||||
| "html4excel": False, | |||||
| "layout_recognize": "DeepDOC", | |||||
| "raptor": {"use_raptor": False}, | |||||
| }, res | |||||
| @pytest.mark.p3 | @pytest.mark.p3 | ||||
| def test_dataset_10k(self, get_http_api_auth): | |||||
| for i in range(10_000): | |||||
| payload = {"name": f"dataset_{i}"} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 0, f"Failed to create dataset {i}" | |||||
| def test_parser_config_none(self, get_http_api_auth): | |||||
| payload = {"name": "default_none", "parser_config": None} | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "payload", | |||||
| [ | |||||
| {"name": "id", "id": "id"}, | |||||
| {"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"}, | |||||
| {"name": "created_by", "created_by": "created_by"}, | |||||
| {"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"name": "create_time", "create_time": 1741671443322}, | |||||
| {"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"name": "update_time", "update_time": 1741671443339}, | |||||
| {"name": "document_count", "document_count": 1}, | |||||
| {"name": "chunk_count", "chunk_count": 1}, | |||||
| {"name": "token_num", "token_num": 1}, | |||||
| {"name": "status", "status": "1"}, | |||||
| {"name": "unknown_field", "unknown_field": "unknown_field"}, | |||||
| ], | |||||
| ) | |||||
| def test_unsupported_field(self, get_http_api_auth, payload): | |||||
| res = create_dataset(get_http_api_auth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Extra inputs are not permitted" in res["message"], res |
| from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
| import pytest | import pytest | ||||
| from common import ( | |||||
| DATASET_NAME_LIMIT, | |||||
| INVALID_API_TOKEN, | |||||
| list_datasets, | |||||
| update_dataset, | |||||
| ) | |||||
| from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset | |||||
| from hypothesis import HealthCheck, example, given, settings | |||||
| from libs.auth import RAGFlowHttpApiAuth | from libs.auth import RAGFlowHttpApiAuth | ||||
| from libs.utils import encode_avatar | from libs.utils import encode_avatar | ||||
| from libs.utils.file_utils import create_image_file | from libs.utils.file_utils import create_image_file | ||||
| from libs.utils.hypothesis_utils import valid_names | |||||
| # TODO: Missing scenario for updating embedding_model with chunk_count != 0 | # TODO: Missing scenario for updating embedding_model with chunk_count != 0 | ||||
| @pytest.mark.p1 | |||||
| class TestAuthorization: | class TestAuthorization: | ||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "auth, expected_code, expected_message", | "auth, expected_code, expected_message", | ||||
| [ | [ | ||||
| "Authentication error: API key is invalid!", | "Authentication error: API key is invalid!", | ||||
| ), | ), | ||||
| ], | ], | ||||
| ids=["empty_auth", "invalid_api_token"], | |||||
| ) | ) | ||||
| def test_invalid_auth(self, auth, expected_code, expected_message): | |||||
| def test_auth_invalid(self, auth, expected_code, expected_message): | |||||
| res = update_dataset(auth, "dataset_id") | res = update_dataset(auth, "dataset_id") | ||||
| assert res["code"] == expected_code | |||||
| assert res["message"] == expected_message | |||||
| assert res["code"] == expected_code, res | |||||
| assert res["message"] == expected_message, res | |||||
| @pytest.mark.p1 | |||||
| class TestDatasetUpdate: | |||||
| class TestRquest: | |||||
| @pytest.mark.p3 | |||||
| def test_bad_content_type(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| BAD_CONTENT_TYPE = "text/xml" | |||||
| res = update_dataset(get_http_api_auth, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE}) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, expected_code, expected_message", | |||||
| "payload, expected_message", | |||||
| [ | [ | ||||
| ("valid_name", 0, ""), | |||||
| ( | |||||
| "a" * (DATASET_NAME_LIMIT + 1), | |||||
| 102, | |||||
| "Dataset name should not be longer than 128 characters.", | |||||
| ), | |||||
| (0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""), | |||||
| ( | |||||
| None, | |||||
| 100, | |||||
| """AttributeError("\'NoneType\' object has no attribute \'strip\'")""", | |||||
| ), | |||||
| pytest.param("", 102, "", marks=pytest.mark.skip(reason="issue/5915")), | |||||
| ("dataset_1", 102, "Duplicated dataset name in updating dataset."), | |||||
| ("DATASET_1", 102, "Duplicated dataset name in updating dataset."), | |||||
| ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), | |||||
| ('"a"', "Invalid request payload: expected object, got str"), | |||||
| ], | ], | ||||
| ids=["malformed_json_syntax", "invalid_request_payload_type"], | |||||
| ) | ) | ||||
| def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message): | |||||
| dataset_ids = add_datasets_func | |||||
| res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name}) | |||||
| assert res["code"] == expected_code | |||||
| if expected_code == 0: | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]}) | |||||
| assert res["data"][0]["name"] == name | |||||
| else: | |||||
| assert res["message"] == expected_message | |||||
| def test_payload_bad(self, get_http_api_auth, add_dataset_func, payload, expected_message): | |||||
| dataset_id = add_dataset_func | |||||
| res = update_dataset(get_http_api_auth, dataset_id, data=payload) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == expected_message, res | |||||
| @pytest.mark.p2 | |||||
| def test_payload_empty(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| res = update_dataset(get_http_api_auth, dataset_id, {}) | |||||
| assert res["code"] == 101, res | |||||
| assert res["message"] == "No properties were modified", res | |||||
| class TestCapability: | |||||
| @pytest.mark.p3 | |||||
| def test_update_dateset_concurrent(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||||
| futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] | |||||
| responses = [f.result() for f in futures] | |||||
| assert all(r["code"] == 0 for r in responses), responses | |||||
| class TestDatasetUpdate: | |||||
| @pytest.mark.p3 | |||||
| def test_dataset_id_not_uuid(self, get_http_api_auth): | |||||
| payload = {"name": "dataset_id_not_uuid"} | |||||
| res = update_dataset(get_http_api_auth, "not_uuid", payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid UUID" in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_dataset_id_wrong_uuid(self, get_http_api_auth): | |||||
| payload = {"name": "wrong_uuid"} | |||||
| res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", payload) | |||||
| assert res["code"] == 102, res | |||||
| assert "lacks permission for dataset" in res["message"], res | |||||
| @pytest.mark.p1 | |||||
| @given(name=valid_names()) | |||||
| @example("a" * 128) | |||||
| @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) | |||||
| def test_name(self, get_http_api_auth, add_dataset_func, name): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"name": name} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["name"] == name, res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "embedding_model, expected_code, expected_message", | |||||
| "name, expected_message", | |||||
| [ | [ | ||||
| ("BAAI/bge-large-zh-v1.5", 0, ""), | |||||
| ("maidalun1020/bce-embedding-base_v1", 0, ""), | |||||
| ( | |||||
| "other_embedding_model", | |||||
| 102, | |||||
| "`embedding_model` other_embedding_model doesn't exist", | |||||
| ), | |||||
| (None, 102, "`embedding_model` can't be empty"), | |||||
| ("", "String should have at least 1 character"), | |||||
| (" ", "String should have at least 1 character"), | |||||
| ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"), | |||||
| (0, "Input should be a valid string"), | |||||
| (None, "Input should be a valid string"), | |||||
| ], | ], | ||||
| ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], | |||||
| ) | ) | ||||
| def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message): | |||||
| def test_name_invalid(self, get_http_api_auth, add_dataset_func, name, expected_message): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model}) | |||||
| assert res["code"] == expected_code | |||||
| if expected_code == 0: | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||||
| assert res["data"][0]["embedding_model"] == embedding_model | |||||
| else: | |||||
| assert res["message"] == expected_message | |||||
| payload = {"name": name} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_name_duplicated(self, get_http_api_auth, add_datasets_func): | |||||
| dataset_ids = add_datasets_func[0] | |||||
| name = "dataset_1" | |||||
| payload = {"name": name} | |||||
| res = update_dataset(get_http_api_auth, dataset_ids, payload) | |||||
| assert res["code"] == 102, res | |||||
| assert res["message"] == f"Dataset name '{name}' already exists", res | |||||
| @pytest.mark.p3 | |||||
| def test_name_case_insensitive(self, get_http_api_auth, add_datasets_func): | |||||
| dataset_id = add_datasets_func[0] | |||||
| name = "DATASET_1" | |||||
| payload = {"name": name} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 102, res | |||||
| assert res["message"] == f"Dataset name '{name}' already exists", res | |||||
| @pytest.mark.p2 | |||||
| def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path): | |||||
| dataset_id = add_dataset_func | |||||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||||
| payload = { | |||||
| "avatar": f"data:image/png;base64,{encode_avatar(fn)}", | |||||
| } | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res | |||||
| @pytest.mark.p2 | |||||
| def test_avatar_exceeds_limit_length(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"avatar": "a" * 65536} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "String should have at most 65535 characters" in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "chunk_method, expected_code, expected_message", | |||||
| "name, avatar_prefix, expected_message", | |||||
| [ | [ | ||||
| ("naive", 0, ""), | |||||
| ("manual", 0, ""), | |||||
| ("qa", 0, ""), | |||||
| ("table", 0, ""), | |||||
| ("paper", 0, ""), | |||||
| ("book", 0, ""), | |||||
| ("laws", 0, ""), | |||||
| ("presentation", 0, ""), | |||||
| ("picture", 0, ""), | |||||
| ("one", 0, ""), | |||||
| ("email", 0, ""), | |||||
| ("tag", 0, ""), | |||||
| ("", 0, ""), | |||||
| ( | |||||
| "other_chunk_method", | |||||
| 102, | |||||
| "'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'email', 'tag']", | |||||
| ), | |||||
| ("empty_prefix", "", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"), | |||||
| ("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"), | |||||
| ("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"), | |||||
| ("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"), | |||||
| ], | ], | ||||
| ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"], | |||||
| ) | ) | ||||
| def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message): | |||||
| dataset_id = add_dataset_func | |||||
| res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method}) | |||||
| assert res["code"] == expected_code | |||||
| if expected_code == 0: | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||||
| if chunk_method != "": | |||||
| assert res["data"][0]["chunk_method"] == chunk_method | |||||
| else: | |||||
| assert res["data"][0]["chunk_method"] == "naive" | |||||
| else: | |||||
| assert res["message"] == expected_message | |||||
| def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path): | |||||
| def test_avatar_invalid_prefix(self, get_http_api_auth, add_dataset_func, tmp_path, name, avatar_prefix, expected_message): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| fn = create_image_file(tmp_path / "ragflow_test.png") | fn = create_image_file(tmp_path / "ragflow_test.png") | ||||
| payload = {"avatar": encode_avatar(fn)} | |||||
| payload = { | |||||
| "name": name, | |||||
| "avatar": f"{avatar_prefix}{encode_avatar(fn)}", | |||||
| } | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == 0 | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_avatar_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"avatar": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["avatar"] is None, res | |||||
| @pytest.mark.p2 | |||||
| def test_description(self, get_http_api_auth, add_dataset_func): | def test_description(self, get_http_api_auth, add_dataset_func): | ||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"description": "description"} | payload = {"description": "description"} | ||||
| assert res["code"] == 0 | assert res["code"] == 0 | ||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | res = list_datasets(get_http_api_auth, {"id": dataset_id}) | ||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["description"] == "description" | assert res["data"][0]["description"] == "description" | ||||
| def test_pagerank(self, get_http_api_auth, add_dataset_func): | |||||
| @pytest.mark.p2 | |||||
| def test_description_exceeds_limit_length(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"pagerank": 1} | |||||
| payload = {"description": "a" * 65536} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == 0 | |||||
| assert res["code"] == 101, res | |||||
| assert "String should have at most 65535 characters" in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_description_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"description": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | res = list_datasets(get_http_api_auth, {"id": dataset_id}) | ||||
| assert res["data"][0]["pagerank"] == 1 | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["description"] is None | |||||
| def test_similarity_threshold(self, get_http_api_auth, add_dataset_func): | |||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | |||||
| "embedding_model", | |||||
| [ | |||||
| "BAAI/bge-large-zh-v1.5@BAAI", | |||||
| "maidalun1020/bce-embedding-base_v1@Youdao", | |||||
| "embedding-3@ZHIPU-AI", | |||||
| ], | |||||
| ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"], | |||||
| ) | |||||
| def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"similarity_threshold": 1} | |||||
| payload = {"embedding_model": embedding_model} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == 0 | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||||
| assert res["data"][0]["similarity_threshold"] == 1 | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["embedding_model"] == embedding_model, res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "name, embedding_model", | |||||
| [ | |||||
| ("unknown_llm_name", "unknown@ZHIPU-AI"), | |||||
| ("unknown_llm_factory", "embedding-3@unknown"), | |||||
| ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), | |||||
| ("tenant_no_auth", "text-embedding-3-small@OpenAI"), | |||||
| ], | |||||
| ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], | |||||
| ) | |||||
| def test_embedding_model_invalid(self, get_http_api_auth, add_dataset_func, name, embedding_model): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"name": name, "embedding_model": embedding_model} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| if "tenant_no_auth" in name: | |||||
| assert res["message"] == f"Unauthorized model: <{embedding_model}>", res | |||||
| else: | |||||
| assert res["message"] == f"Unsupported model: <{embedding_model}>", res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "name, embedding_model", | |||||
| [ | |||||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | |||||
| ("missing_model_name", "@BAAI"), | |||||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | |||||
| ("whitespace_only_model_name", " @BAAI"), | |||||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | |||||
| ], | |||||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ) | |||||
| def test_embedding_model_format(self, get_http_api_auth, add_dataset_func, name, embedding_model): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"name": name, "embedding_model": embedding_model} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| if name == "missing_at": | |||||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | |||||
| else: | |||||
| assert "Both model_name and provider must be non-empty strings" in res["message"], res | |||||
| @pytest.mark.p2 | |||||
| def test_embedding_model_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"embedding_model": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid string" in res["message"], res | |||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "permission, expected_code", | |||||
| "name, permission", | |||||
| [ | [ | ||||
| ("me", 0), | |||||
| ("team", 0), | |||||
| ("", 0), | |||||
| ("ME", 102), | |||||
| ("TEAM", 102), | |||||
| ("other_permission", 102), | |||||
| ("me", "me"), | |||||
| ("team", "team"), | |||||
| ("me_upercase", "ME"), | |||||
| ("team_upercase", "TEAM"), | |||||
| ], | ], | ||||
| ids=["me", "team", "me_upercase", "team_upercase"], | |||||
| ) | ) | ||||
| def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code): | |||||
| def test_permission(self, get_http_api_auth, add_dataset_func, name, permission): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"name": name, "permission": permission} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["permission"] == permission.lower(), res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "permission", | |||||
| [ | |||||
| "", | |||||
| "unknown", | |||||
| list(), | |||||
| ], | |||||
| ids=["empty", "unknown", "type_error"], | |||||
| ) | |||||
| def test_permission_invalid(self, get_http_api_auth, add_dataset_func, permission): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"permission": permission} | payload = {"permission": permission} | ||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == expected_code | |||||
| assert res["code"] == 101 | |||||
| assert "Input should be 'me' or 'team'" in res["message"] | |||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | |||||
| if expected_code == 0 and permission != "": | |||||
| assert res["data"][0]["permission"] == permission | |||||
| if permission == "": | |||||
| assert res["data"][0]["permission"] == "me" | |||||
| @pytest.mark.p3 | |||||
| def test_permission_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"name": "test_permission_none", "permission": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be 'me' or 'team'" in res["message"], res | |||||
| def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func): | |||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | |||||
| "chunk_method", | |||||
| [ | |||||
| "naive", | |||||
| "book", | |||||
| "email", | |||||
| "laws", | |||||
| "manual", | |||||
| "one", | |||||
| "paper", | |||||
| "picture", | |||||
| "presentation", | |||||
| "qa", | |||||
| "table", | |||||
| "tag", | |||||
| ], | |||||
| ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], | |||||
| ) | |||||
| def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"vector_similarity_weight": 1} | |||||
| payload = {"chunk_method": chunk_method} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["chunk_method"] == chunk_method, res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "chunk_method", | |||||
| [ | |||||
| "", | |||||
| "unknown", | |||||
| list(), | |||||
| ], | |||||
| ids=["empty", "unknown", "type_error"], | |||||
| ) | |||||
| def test_chunk_method_invalid(self, get_http_api_auth, add_dataset_func, chunk_method): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"chunk_method": chunk_method} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_chunk_method_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"chunk_method": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) | |||||
| def test_pagerank(self, get_http_api_auth, add_dataset_func, pagerank): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"pagerank": pagerank} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == 0 | assert res["code"] == 0 | ||||
| res = list_datasets(get_http_api_auth, {"id": dataset_id}) | res = list_datasets(get_http_api_auth, {"id": dataset_id}) | ||||
| assert res["data"][0]["vector_similarity_weight"] == 1 | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["pagerank"] == pagerank | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "pagerank, expected_message", | |||||
| [ | |||||
| (-1, "Input should be greater than or equal to 0"), | |||||
| (101, "Input should be less than or equal to 100"), | |||||
| ], | |||||
| ids=["min_limit", "max_limit"], | |||||
| ) | |||||
| def test_pagerank_invalid(self, get_http_api_auth, add_dataset_func, pagerank, expected_message): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"pagerank": pagerank} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| def test_invalid_dataset_id(self, get_http_api_auth): | |||||
| res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}) | |||||
| assert res["code"] == 102 | |||||
| assert res["message"] == "You don't own the dataset" | |||||
| @pytest.mark.p3 | |||||
| def test_pagerank_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"pagerank": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid integer" in res["message"], res | |||||
| @pytest.mark.p1 | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "payload", | |||||
| "parser_config", | |||||
| [ | [ | ||||
| {"chunk_count": 1}, | |||||
| {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"create_time": 1741671443322}, | |||||
| {"created_by": "aa"}, | |||||
| {"document_count": 1}, | |||||
| {"id": "id"}, | |||||
| {"status": "1"}, | |||||
| {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, | |||||
| {"token_num": 1}, | |||||
| {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"update_time": 1741671443339}, | |||||
| {"auto_keywords": 0}, | |||||
| {"auto_keywords": 16}, | |||||
| {"auto_keywords": 32}, | |||||
| {"auto_questions": 0}, | |||||
| {"auto_questions": 5}, | |||||
| {"auto_questions": 10}, | |||||
| {"chunk_token_num": 1}, | |||||
| {"chunk_token_num": 1024}, | |||||
| {"chunk_token_num": 2048}, | |||||
| {"delimiter": "\n"}, | |||||
| {"delimiter": " "}, | |||||
| {"html4excel": True}, | |||||
| {"html4excel": False}, | |||||
| {"layout_recognize": "DeepDOC"}, | |||||
| {"layout_recognize": "Plain Text"}, | |||||
| {"tag_kb_ids": ["1", "2"]}, | |||||
| {"topn_tags": 1}, | |||||
| {"topn_tags": 5}, | |||||
| {"topn_tags": 10}, | |||||
| {"filename_embd_weight": 0.1}, | |||||
| {"filename_embd_weight": 0.5}, | |||||
| {"filename_embd_weight": 1.0}, | |||||
| {"task_page_size": 1}, | |||||
| {"task_page_size": None}, | |||||
| {"pages": [[1, 100]]}, | |||||
| {"pages": None}, | |||||
| {"graphrag": {"use_graphrag": True}}, | |||||
| {"graphrag": {"use_graphrag": False}}, | |||||
| {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}, | |||||
| {"graphrag": {"method": "general"}}, | |||||
| {"graphrag": {"method": "light"}}, | |||||
| {"graphrag": {"community": True}}, | |||||
| {"graphrag": {"community": False}}, | |||||
| {"graphrag": {"resolution": True}}, | |||||
| {"graphrag": {"resolution": False}}, | |||||
| {"raptor": {"use_raptor": True}}, | |||||
| {"raptor": {"use_raptor": False}}, | |||||
| {"raptor": {"prompt": "Who are you?"}}, | |||||
| {"raptor": {"max_token": 1}}, | |||||
| {"raptor": {"max_token": 1024}}, | |||||
| {"raptor": {"max_token": 2048}}, | |||||
| {"raptor": {"threshold": 0.0}}, | |||||
| {"raptor": {"threshold": 0.5}}, | |||||
| {"raptor": {"threshold": 1.0}}, | |||||
| {"raptor": {"max_cluster": 1}}, | |||||
| {"raptor": {"max_cluster": 512}}, | |||||
| {"raptor": {"max_cluster": 1024}}, | |||||
| {"raptor": {"random_seed": 0}}, | |||||
| ], | |||||
| ids=[ | |||||
| "auto_keywords_min", | |||||
| "auto_keywords_mid", | |||||
| "auto_keywords_max", | |||||
| "auto_questions_min", | |||||
| "auto_questions_mid", | |||||
| "auto_questions_max", | |||||
| "chunk_token_num_min", | |||||
| "chunk_token_num_mid", | |||||
| "chunk_token_num_max", | |||||
| "delimiter", | |||||
| "delimiter_space", | |||||
| "html4excel_true", | |||||
| "html4excel_false", | |||||
| "layout_recognize_DeepDOC", | |||||
| "layout_recognize_navie", | |||||
| "tag_kb_ids", | |||||
| "topn_tags_min", | |||||
| "topn_tags_mid", | |||||
| "topn_tags_max", | |||||
| "filename_embd_weight_min", | |||||
| "filename_embd_weight_mid", | |||||
| "filename_embd_weight_max", | |||||
| "task_page_size_min", | |||||
| "task_page_size_None", | |||||
| "pages", | |||||
| "pages_none", | |||||
| "graphrag_true", | |||||
| "graphrag_false", | |||||
| "graphrag_entity_types", | |||||
| "graphrag_method_general", | |||||
| "graphrag_method_light", | |||||
| "graphrag_community_true", | |||||
| "graphrag_community_false", | |||||
| "graphrag_resolution_true", | |||||
| "graphrag_resolution_false", | |||||
| "raptor_true", | |||||
| "raptor_false", | |||||
| "raptor_prompt", | |||||
| "raptor_max_token_min", | |||||
| "raptor_max_token_mid", | |||||
| "raptor_max_token_max", | |||||
| "raptor_threshold_min", | |||||
| "raptor_threshold_mid", | |||||
| "raptor_threshold_max", | |||||
| "raptor_max_cluster_min", | |||||
| "raptor_max_cluster_mid", | |||||
| "raptor_max_cluster_max", | |||||
| "raptor_random_seed_min", | |||||
| ], | ], | ||||
| ) | ) | ||||
| def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload): | |||||
| def test_parser_config(self, get_http_api_auth, add_dataset_func, parser_config): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"parser_config": parser_config} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | res = update_dataset(get_http_api_auth, dataset_id, payload) | ||||
| assert res["code"] == 101 | |||||
| assert "is readonly" in res["message"] | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| for k, v in parser_config.items(): | |||||
| if isinstance(v, dict): | |||||
| for kk, vv in v.items(): | |||||
| assert res["data"][0]["parser_config"][k][kk] == vv, res | |||||
| else: | |||||
| assert res["data"][0]["parser_config"][k] == v, res | |||||
| def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func): | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "parser_config, expected_message", | |||||
| [ | |||||
| ({"auto_keywords": -1}, "Input should be greater than or equal to 0"), | |||||
| ({"auto_keywords": 33}, "Input should be less than or equal to 32"), | |||||
| ({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"auto_questions": -1}, "Input should be greater than or equal to 0"), | |||||
| ({"auto_questions": 11}, "Input should be less than or equal to 10"), | |||||
| ({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"chunk_token_num": 0}, "Input should be greater than or equal to 1"), | |||||
| ({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"), | |||||
| ({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"delimiter": ""}, "String should have at least 1 character"), | |||||
| ({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"), | |||||
| ({"tag_kb_ids": "1,2"}, "Input should be a valid list"), | |||||
| ({"tag_kb_ids": [1, 2]}, "Input should be a valid string"), | |||||
| ({"topn_tags": 0}, "Input should be greater than or equal to 1"), | |||||
| ({"topn_tags": 11}, "Input should be less than or equal to 10"), | |||||
| ({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"), | |||||
| ({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"), | |||||
| ({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"), | |||||
| ({"task_page_size": 0}, "Input should be greater than or equal to 1"), | |||||
| ({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"pages": "1,2"}, "Input should be a valid list"), | |||||
| ({"pages": ["1,2"]}, "Input should be a valid list"), | |||||
| ({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"), | |||||
| ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), | |||||
| ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), | |||||
| ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), | |||||
| ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), | |||||
| ({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"), | |||||
| ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"), | |||||
| ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"), | |||||
| ({"raptor": {"prompt": ""}}, "String should have at least 1 character"), | |||||
| ({"raptor": {"prompt": " "}}, "String should have at least 1 character"), | |||||
| ({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"), | |||||
| ({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"), | |||||
| ({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"), | |||||
| ({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"), | |||||
| ({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"), | |||||
| ({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"), | |||||
| ({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"), | |||||
| ({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"), | |||||
| ({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), | |||||
| ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), | |||||
| ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), | |||||
| ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), | |||||
| ], | |||||
| ids=[ | |||||
| "auto_keywords_min_limit", | |||||
| "auto_keywords_max_limit", | |||||
| "auto_keywords_float_not_allowed", | |||||
| "auto_keywords_type_invalid", | |||||
| "auto_questions_min_limit", | |||||
| "auto_questions_max_limit", | |||||
| "auto_questions_float_not_allowed", | |||||
| "auto_questions_type_invalid", | |||||
| "chunk_token_num_min_limit", | |||||
| "chunk_token_num_max_limit", | |||||
| "chunk_token_num_float_not_allowed", | |||||
| "chunk_token_num_type_invalid", | |||||
| "delimiter_empty", | |||||
| "html4excel_type_invalid", | |||||
| "tag_kb_ids_not_list", | |||||
| "tag_kb_ids_int_in_list", | |||||
| "topn_tags_min_limit", | |||||
| "topn_tags_max_limit", | |||||
| "topn_tags_float_not_allowed", | |||||
| "topn_tags_type_invalid", | |||||
| "filename_embd_weight_min_limit", | |||||
| "filename_embd_weight_max_limit", | |||||
| "filename_embd_weight_type_invalid", | |||||
| "task_page_size_min_limit", | |||||
| "task_page_size_float_not_allowed", | |||||
| "task_page_size_type_invalid", | |||||
| "pages_not_list", | |||||
| "pages_not_list_in_list", | |||||
| "pages_not_int_list", | |||||
| "graphrag_type_invalid", | |||||
| "graphrag_entity_types_not_list", | |||||
| "graphrag_entity_types_not_str_in_list", | |||||
| "graphrag_method_unknown", | |||||
| "graphrag_method_none", | |||||
| "graphrag_community_type_invalid", | |||||
| "graphrag_resolution_type_invalid", | |||||
| "raptor_type_invalid", | |||||
| "raptor_prompt_empty", | |||||
| "raptor_prompt_space", | |||||
| "raptor_max_token_min_limit", | |||||
| "raptor_max_token_max_limit", | |||||
| "raptor_max_token_float_not_allowed", | |||||
| "raptor_max_token_type_invalid", | |||||
| "raptor_threshold_min_limit", | |||||
| "raptor_threshold_max_limit", | |||||
| "raptor_threshold_type_invalid", | |||||
| "raptor_max_cluster_min_limit", | |||||
| "raptor_max_cluster_max_limit", | |||||
| "raptor_max_cluster_float_not_allowed", | |||||
| "raptor_max_cluster_type_invalid", | |||||
| "raptor_random_seed_min_limit", | |||||
| "raptor_random_seed_float_not_allowed", | |||||
| "raptor_random_seed_type_invalid", | |||||
| "parser_config_type_invalid", | |||||
| ], | |||||
| ) | |||||
| def test_parser_config_invalid(self, get_http_api_auth, add_dataset_func, parser_config, expected_message): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0}) | |||||
| assert res["code"] == 100 | |||||
| payload = {"parser_config": parser_config} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p2 | |||||
| def test_parser_config_empty(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | |||||
| payload = {"parser_config": {}} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(get_http_api_auth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["parser_config"] == {} | |||||
| # @pytest.mark.p2 | |||||
| # def test_parser_config_unset(self, get_http_api_auth, add_dataset_func): | |||||
| # dataset_id = add_dataset_func | |||||
| # payload = {"name": "default_unset"} | |||||
| # res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| # assert res["code"] == 0, res | |||||
| # res = list_datasets(get_http_api_auth) | |||||
| # assert res["code"] == 0, res | |||||
| # assert res["data"][0]["parser_config"] == { | |||||
| # "chunk_token_num": 128, | |||||
| # "delimiter": r"\n", | |||||
| # "html4excel": False, | |||||
| # "layout_recognize": "DeepDOC", | |||||
| # "raptor": {"use_raptor": False}, | |||||
| # }, res | |||||
| @pytest.mark.p3 | @pytest.mark.p3 | ||||
| def test_concurrent_update(self, get_http_api_auth, add_dataset_func): | |||||
| def test_parser_config_none(self, get_http_api_auth, add_dataset_func): | |||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"parser_config": None} | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res | |||||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||||
| futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] | |||||
| responses = [f.result() for f in futures] | |||||
| assert all(r["code"] == 0 for r in responses) | |||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "payload", | |||||
| [ | |||||
| {"id": "id"}, | |||||
| {"tenant_id": "e57c1966f99211efb41e9e45646e0111"}, | |||||
| {"created_by": "created_by"}, | |||||
| {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"create_time": 1741671443322}, | |||||
| {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"}, | |||||
| {"update_time": 1741671443339}, | |||||
| {"document_count": 1}, | |||||
| {"chunk_count": 1}, | |||||
| {"token_num": 1}, | |||||
| {"status": "1"}, | |||||
| {"unknown_field": "unknown_field"}, | |||||
| ], | |||||
| ) | |||||
| def test_unsupported_field(self, get_http_api_auth, add_dataset_func, payload): | |||||
| dataset_id = add_dataset_func | |||||
| res = update_dataset(get_http_api_auth, dataset_id, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Extra inputs are not permitted" in res["message"], res |
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
| from time import sleep | |||||
| import pytest | import pytest | ||||
| from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets | from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets | ||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) | ||||
| parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | ||||
| sleep(1) | |||||
| res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) | ||||
| assert res["code"] == 0 | assert res["code"] == 0 | ||||
| validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) | validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) |