### What problem does this PR solve? Optimize dataset validation and add function docs ### Type of change - [x] Refactoringtags/v0.19.0
| @@ -19,7 +19,6 @@ import logging | |||
| from flask import request | |||
| from peewee import OperationalError | |||
| from pydantic import ValidationError | |||
| from api import settings | |||
| from api.db import FileSource, StatusEnum | |||
| @@ -41,8 +40,9 @@ from api.utils.api_utils import ( | |||
| token_required, | |||
| valid, | |||
| valid_parser_config, | |||
| verify_embedding_availability, | |||
| ) | |||
| from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message | |||
| from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request | |||
| @manager.route("/datasets", methods=["POST"]) # noqa: F821 | |||
| @@ -107,21 +107,14 @@ def create(tenant_id): | |||
| data: | |||
| type: object | |||
| """ | |||
| req_i = request.json | |||
| if not isinstance(req_i, dict): | |||
| return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}") | |||
| try: | |||
| req_v = CreateDatasetReq(**req_i) | |||
| except ValidationError as e: | |||
| return get_error_argument_result(format_validation_error_message(e)) | |||
| # Field name transformations during model dump: | |||
| # | Original | Dump Output | | |||
| # |----------------|-------------| | |||
| # | embedding_model| embd_id | | |||
| # | chunk_method | parser_id | | |||
| req = req_v.model_dump(by_alias=True) | |||
| req, err = validate_and_parse_json_request(request, CreateDatasetReq) | |||
| if err is not None: | |||
| return get_error_argument_result(err) | |||
| try: | |||
| if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| @@ -146,21 +139,9 @@ def create(tenant_id): | |||
| if not req.get("embd_id"): | |||
| req["embd_id"] = t.embd_id | |||
| else: | |||
| builtin_embedding_models = [ | |||
| "BAAI/bge-large-zh-v1.5@BAAI", | |||
| "maidalun1020/bce-embedding-base_v1@Youdao", | |||
| ] | |||
| is_builtin_model = req["embd_id"] in builtin_embedding_models | |||
| try: | |||
| # model name must be model_name@model_factory | |||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"]) | |||
| is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding") | |||
| is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding") | |||
| if not (is_supported_model and (is_builtin_model or is_tenant_model)): | |||
| return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported") | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return get_error_data_result(message="Database operation failed") | |||
| ok, err = verify_embedding_availability(req["embd_id"], tenant_id) | |||
| if not ok: | |||
| return err | |||
| try: | |||
| if not KnowledgebaseService.save(**req): | |||
| @@ -13,22 +13,22 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| import os | |||
| from datetime import date | |||
| from enum import IntEnum, Enum | |||
| import json | |||
| from enum import Enum, IntEnum | |||
| import rag.utils | |||
| import rag.utils.es_conn | |||
| import rag.utils.infinity_conn | |||
| import rag.utils.opensearch_coon | |||
| import rag.utils | |||
| from rag.nlp import search | |||
| from graphrag import search as kg_search | |||
| from api.utils import get_base_config, decrypt_database_config | |||
| from api.constants import RAG_FLOW_SERVICE_NAME | |||
| from api.utils import decrypt_database_config, get_base_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from graphrag import search as kg_search | |||
| from rag.nlp import search | |||
| LIGHTEN = int(os.environ.get('LIGHTEN', "0")) | |||
| LIGHTEN = int(os.environ.get("LIGHTEN", "0")) | |||
| LLM = None | |||
| LLM_FACTORY = None | |||
| @@ -45,7 +45,7 @@ HOST_PORT = None | |||
| SECRET_KEY = None | |||
| FACTORY_LLM_INFOS = None | |||
| DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') | |||
| DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") | |||
| DATABASE = decrypt_database_config(name=DATABASE_TYPE) | |||
| # authentication | |||
| @@ -66,11 +66,13 @@ kg_retrievaler = None | |||
| # user registration switch | |||
| REGISTER_ENABLED = 1 | |||
| BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"] | |||
| def init_settings(): | |||
| global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED | |||
| LIGHTEN = int(os.environ.get('LIGHTEN', "0")) | |||
| DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') | |||
| LIGHTEN = int(os.environ.get("LIGHTEN", "0")) | |||
| DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") | |||
| DATABASE = decrypt_database_config(name=DATABASE_TYPE) | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| LLM_DEFAULT_MODELS = LLM.get("default_models", {}) | |||
| @@ -79,8 +81,8 @@ def init_settings(): | |||
| try: | |||
| REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) | |||
| except Exception: | |||
| pass | |||
| pass | |||
| try: | |||
| with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: | |||
| FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] | |||
| @@ -89,7 +91,7 @@ def init_settings(): | |||
| global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL | |||
| if not LIGHTEN: | |||
| EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI" | |||
| EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0] | |||
| if LLM_DEFAULT_MODELS: | |||
| CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL) | |||
| @@ -103,30 +105,25 @@ def init_settings(): | |||
| EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "") | |||
| RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "") | |||
| ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "") | |||
| IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + ( | |||
| f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "") | |||
| IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "") | |||
| global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY | |||
| API_KEY = LLM.get("api_key", "") | |||
| API_KEY = LLM.get("api_key") | |||
| PARSERS = LLM.get( | |||
| "parsers", | |||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag") | |||
| "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag" | |||
| ) | |||
| HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| SECRET_KEY = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, | |||
| {}).get("secret_key", str(date.today())) | |||
| SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", str(date.today())) | |||
| global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG | |||
| # authentication | |||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||
| # client | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||
| "client", {}).get( | |||
| "switch", False) | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) | |||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||
| FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") | |||
| @@ -134,7 +131,7 @@ def init_settings(): | |||
| OAUTH_CONFIG = get_base_config("oauth", {}) | |||
| global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler | |||
| DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") | |||
| DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch") | |||
| # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch") | |||
| lower_case_doc_engine = DOC_ENGINE.lower() | |||
| if lower_case_doc_engine == "elasticsearch": | |||
| @@ -36,11 +36,13 @@ from flask import ( | |||
| request as flask_request, | |||
| ) | |||
| from itsdangerous import URLSafeTimedSerializer | |||
| from peewee import OperationalError | |||
| from werkzeug.http import HTTP_STATUS_CODES | |||
| from api import settings | |||
| from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC | |||
| from api.db.db_models import APIToken | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.utils import CustomJSONEncoder, get_uuid, json_dumps | |||
| requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) | |||
| @@ -464,3 +466,55 @@ def check_duplicate_ids(ids, id_type="item"): | |||
| # Return unique IDs and error messages | |||
| return list(set(ids)), duplicate_messages | |||
| def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: | |||
| """Verifies availability of an embedding model for a specific tenant. | |||
| Implements a four-stage validation process: | |||
| 1. Model identifier parsing and validation | |||
| 2. System support verification | |||
| 3. Tenant authorization check | |||
| 4. Database operation error handling | |||
| Args: | |||
| embd_id (str): Unique identifier for the embedding model in format "model_name@factory" | |||
| tenant_id (str): Tenant identifier for access control | |||
| Returns: | |||
| tuple[bool, Response | None]: | |||
| - First element (bool): | |||
| - True: Model is available and authorized | |||
| - False: Validation failed | |||
| - Second element contains: | |||
| - None on success | |||
| - Error detail dict on failure | |||
| Raises: | |||
| ValueError: When model identifier format is invalid | |||
| OperationalError: When database connection fails (auto-handled) | |||
| Examples: | |||
| >>> verify_embedding_availability("text-embedding@openai", "tenant_123") | |||
| (True, None) | |||
| >>> verify_embedding_availability("invalid_model", "tenant_123") | |||
| (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"}) | |||
| """ | |||
| try: | |||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id) | |||
| if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"): | |||
| return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") | |||
| # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key | |||
| tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) | |||
| is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) | |||
| is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS | |||
| if not (is_builtin_model or is_tenant_model): | |||
| return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>") | |||
| except OperationalError as e: | |||
| logging.exception(e) | |||
| return False, get_error_data_result(message="Database operation failed") | |||
| return True, None | |||
| @@ -14,13 +14,102 @@ | |||
| # limitations under the License. | |||
| # | |||
| from enum import auto | |||
| from typing import Annotated, List, Optional | |||
| from typing import Annotated, Any | |||
| from flask import Request | |||
| from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator | |||
| from strenum import StrEnum | |||
| from werkzeug.exceptions import BadRequest, UnsupportedMediaType | |||
| from api.constants import DATASET_NAME_LIMIT | |||
| def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]: | |||
| """Validates and parses JSON requests through a multi-stage validation pipeline. | |||
| Implements a robust four-stage validation process: | |||
| 1. Content-Type verification (must be application/json) | |||
| 2. JSON syntax validation | |||
| 3. Payload structure type checking | |||
| 4. Pydantic model validation with error formatting | |||
| Args: | |||
| request (Request): Flask request object containing HTTP payload | |||
| Returns: | |||
| tuple[Dict[str, Any] | None, str | None]: | |||
| - First element: | |||
| - Validated dictionary on success | |||
| - None on validation failure | |||
| - Second element: | |||
| - None on success | |||
| - Diagnostic error message on failure | |||
| Raises: | |||
| UnsupportedMediaType: When Content-Type ≠ application/json | |||
| BadRequest: For structural JSON syntax errors | |||
| ValidationError: When payload violates Pydantic schema rules | |||
| Examples: | |||
| Successful validation: | |||
| ```python | |||
| # Input: {"name": "Dataset1", "format": "csv"} | |||
| # Returns: ({"name": "Dataset1", "format": "csv"}, None) | |||
| ``` | |||
| Invalid Content-Type: | |||
| ```python | |||
| # Returns: (None, "Unsupported content type: Expected application/json, got text/xml") | |||
| ``` | |||
| Malformed JSON: | |||
| ```python | |||
| # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding") | |||
| ``` | |||
| """ | |||
| try: | |||
| payload = request.get_json() or {} | |||
| except UnsupportedMediaType: | |||
| return None, f"Unsupported content type: Expected application/json, got {request.content_type}" | |||
| except BadRequest: | |||
| return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding" | |||
| if not isinstance(payload, dict): | |||
| return None, f"Invalid request payload: expected object, got {type(payload).__name__}" | |||
| try: | |||
| validated_request = validator(**payload) | |||
| except ValidationError as e: | |||
| return None, format_validation_error_message(e) | |||
| parsed_payload = validated_request.model_dump(by_alias=True) | |||
| return parsed_payload, None | |||
| def format_validation_error_message(e: ValidationError) -> str: | |||
| """Formats validation errors into a standardized string format. | |||
| Processes pydantic ValidationError objects to create human-readable error messages | |||
| containing field locations, error descriptions, and input values. | |||
| Args: | |||
| e (ValidationError): The validation error instance containing error details | |||
| Returns: | |||
| str: Formatted error messages joined by newlines. Each line contains: | |||
| - Field path (dot-separated) | |||
| - Error message | |||
| - Truncated input value (max 128 chars) | |||
| Example: | |||
| >>> try: | |||
| ... UserModel(name=123, email="invalid") | |||
| ... except ValidationError as e: | |||
| ... print(format_validation_error_message(e)) | |||
| Field: <name> - Message: <Input should be a valid string> - Value: <123> | |||
| Field: <email> - Message: <value is not a valid email address> - Value: <invalid> | |||
| """ | |||
| error_messages = [] | |||
| for error in e.errors(): | |||
| @@ -86,7 +175,7 @@ class RaptorConfig(Base): | |||
| class GraphragConfig(Base): | |||
| use_graphrag: bool = Field(default=False) | |||
| entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"]) | |||
| entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"]) | |||
| method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light) | |||
| community: bool = Field(default=False) | |||
| resolution: bool = Field(default=False) | |||
| @@ -97,30 +186,59 @@ class ParserConfig(Base): | |||
| auto_questions: int = Field(default=0, ge=0, le=10) | |||
| chunk_token_num: int = Field(default=128, ge=1, le=2048) | |||
| delimiter: str = Field(default=r"\n", min_length=1) | |||
| graphrag: Optional[GraphragConfig] = None | |||
| graphrag: GraphragConfig | None = None | |||
| html4excel: bool = False | |||
| layout_recognize: str = "DeepDOC" | |||
| raptor: Optional[RaptorConfig] = None | |||
| tag_kb_ids: List[str] = Field(default_factory=list) | |||
| raptor: RaptorConfig | None = None | |||
| tag_kb_ids: list[str] = Field(default_factory=list) | |||
| topn_tags: int = Field(default=1, ge=1, le=10) | |||
| filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0) | |||
| task_page_size: Optional[int] = Field(default=None, ge=1) | |||
| pages: Optional[List[List[int]]] = None | |||
| filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0) | |||
| task_page_size: int | None = Field(default=None, ge=1) | |||
| pages: list[list[int]] | None = None | |||
| class CreateDatasetReq(Base): | |||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)] | |||
| avatar: Optional[str] = Field(default=None, max_length=65535) | |||
| description: Optional[str] = Field(default=None, max_length=65535) | |||
| embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] | |||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] | |||
| avatar: str | None = Field(default=None, max_length=65535) | |||
| description: str | None = Field(default=None, max_length=65535) | |||
| embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")] | |||
| permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)] | |||
| chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")] | |||
| pagerank: int = Field(default=0, ge=0, le=100) | |||
| parser_config: Optional[ParserConfig] = Field(default=None) | |||
| parser_config: ParserConfig | None = Field(default=None) | |||
| @field_validator("avatar") | |||
| @classmethod | |||
| def validate_avatar_base64(cls, v: str) -> str: | |||
| """Validates Base64-encoded avatar string format and MIME type compliance. | |||
| Implements a three-stage validation workflow: | |||
| 1. MIME prefix existence check | |||
| 2. MIME type format validation | |||
| 3. Supported type verification | |||
| Args: | |||
| v (str): Raw avatar field value | |||
| Returns: | |||
| str: Validated Base64 string | |||
| Raises: | |||
| ValueError: For structural errors in these cases: | |||
| - Missing MIME prefix header | |||
| - Invalid MIME prefix format | |||
| - Unsupported image MIME type | |||
| Example: | |||
| ```python | |||
| # Valid case | |||
| CreateDatasetReq(avatar="...") | |||
| # Invalid cases | |||
| CreateDatasetReq(avatar="image/jpeg;base64,...") # Missing 'data:' prefix | |||
| CreateDatasetReq(avatar="data:video/mp4;base64,...") # Unsupported MIME type | |||
| ``` | |||
| """ | |||
| if v is None: | |||
| return v | |||
| @@ -141,22 +259,83 @@ class CreateDatasetReq(Base): | |||
| @field_validator("embedding_model", mode="after") | |||
| @classmethod | |||
| def validate_embedding_model(cls, v: str) -> str: | |||
| """Validates embedding model identifier format compliance. | |||
| Validation pipeline: | |||
| 1. Structural format verification | |||
| 2. Component non-empty check | |||
| 3. Value normalization | |||
| Args: | |||
| v (str): Raw model identifier | |||
| Returns: | |||
| str: Validated <model_name>@<provider> format | |||
| Raises: | |||
| ValueError: For these violations: | |||
| - Missing @ separator | |||
| - Empty model_name/provider | |||
| - Invalid component structure | |||
| Examples: | |||
| Valid: "text-embedding-3-large@openai" | |||
| Invalid: "invalid_model" (no @) | |||
| Invalid: "@openai" (empty model_name) | |||
| Invalid: "text-embedding-3-large@" (empty provider) | |||
| """ | |||
| if "@" not in v: | |||
| raise ValueError("Embedding model must be xxx@yyy") | |||
| raise ValueError("Embedding model identifier must follow <model_name>@<provider> format") | |||
| components = v.split("@", 1) | |||
| if len(components) != 2 or not all(components): | |||
| raise ValueError("Both model_name and provider must be non-empty strings") | |||
| model_name, provider = components | |||
| if not model_name.strip() or not provider.strip(): | |||
| raise ValueError("Model name and provider cannot be whitespace-only strings") | |||
| return v | |||
| @field_validator("permission", mode="before") | |||
| @classmethod | |||
| def permission_auto_lowercase(cls, v: str) -> str: | |||
| if isinstance(v, str): | |||
| return v.lower() | |||
| return v | |||
| """Normalize permission input to lowercase for consistent PermissionEnum matching. | |||
| Args: | |||
| v (str): Raw input value for the permission field | |||
| Returns: | |||
| Lowercase string if input is string type, otherwise returns original value | |||
| Behavior: | |||
| - Converts string inputs to lowercase (e.g., "ME" → "me") | |||
| - Non-string values pass through unchanged | |||
| - Works in validation pre-processing stage (before enum conversion) | |||
| """ | |||
| return v.lower() if isinstance(v, str) else v | |||
| @field_validator("parser_config", mode="after") | |||
| @classmethod | |||
| def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]: | |||
| if v is not None: | |||
| json_str = v.model_dump_json() | |||
| if len(json_str) > 65535: | |||
| raise ValueError("Parser config have at most 65535 characters") | |||
| def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None: | |||
| """Validates serialized JSON length constraints for parser configuration. | |||
| Implements a three-stage validation workflow: | |||
| 1. Null check - bypass validation for empty configurations | |||
| 2. Model serialization - convert Pydantic model to JSON string | |||
| 3. Size verification - enforce maximum allowed payload size | |||
| Args: | |||
| v (ParserConfig | None): Raw parser configuration object | |||
| Returns: | |||
| ParserConfig | None: Validated configuration object | |||
| Raises: | |||
| ValueError: When serialized JSON exceeds 65,535 characters | |||
| """ | |||
| if v is None: | |||
| return v | |||
| if (json_str := v.model_dump_json()) and len(json_str) > 65535: | |||
| raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}") | |||
| return v | |||
| @@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255 | |||
| # DATASET MANAGEMENT | |||
| def create_dataset(auth, payload=None): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) | |||
| def create_dataset(auth, payload=None, headers=HEADERS, data=None): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) | |||
| return res.json() | |||
| def list_datasets(auth, params=None): | |||
| res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params) | |||
| def list_datasets(auth, params=None, headers=HEADERS): | |||
| res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params) | |||
| return res.json() | |||
| def update_dataset(auth, dataset_id, payload=None): | |||
| res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=HEADERS, auth=auth, json=payload) | |||
| def update_dataset(auth, dataset_id, payload=None, headers=HEADERS): | |||
| res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| def delete_datasets(auth, payload=None): | |||
| res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) | |||
| def delete_datasets(auth, payload=None, headers=HEADERS): | |||
| res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| @@ -98,6 +98,25 @@ class TestDatasetCreation: | |||
| assert res["code"] == 101, res | |||
| assert res["message"] == f"Dataset name '{name.lower()}' already exists", res | |||
| def test_bad_content_type(self, get_http_api_auth): | |||
| BAD_CONTENT_TYPE = "text/xml" | |||
| res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE}) | |||
| assert res["code"] == 101, res | |||
| assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"), | |||
| ('"a"', "Invalid request payload: expected objec"), | |||
| ], | |||
| ids=["malformed_json_syntax", "invalid_request_payload_type"], | |||
| ) | |||
| def test_bad_payload(self, get_http_api_auth, payload, expected_message): | |||
| res = create_dataset(get_http_api_auth, data=payload) | |||
| assert res["code"] == 101, res | |||
| assert expected_message in res["message"], res | |||
| def test_avatar(self, get_http_api_auth, tmp_path): | |||
| fn = create_image_file(tmp_path / "ragflow_test.png") | |||
| payload = { | |||
| @@ -158,7 +177,7 @@ class TestDatasetCreation: | |||
| ("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"), | |||
| ("embedding_model_default", None), | |||
| ], | |||
| ids=["builtin_baai", "builtin_youdao", "tenant__zhipu", "default"], | |||
| ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"], | |||
| ) | |||
| def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model): | |||
| if embedding_model is None: | |||
| @@ -178,29 +197,39 @@ class TestDatasetCreation: | |||
| [ | |||
| ("unknown_llm_name", "unknown@ZHIPU-AI"), | |||
| ("unknown_llm_factory", "embedding-3@unknown"), | |||
| ("tenant_no_auth", "deepseek-chat@DeepSeek"), | |||
| ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"), | |||
| ("tenant_no_auth", "text-embedding-3-small@OpenAI"), | |||
| ], | |||
| ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth"], | |||
| ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"], | |||
| ) | |||
| def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model): | |||
| payload = {"name": name, "embedding_model": embedding_model} | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 101, res | |||
| assert res["message"] == f"The embedding_model '{embedding_model}' is not supported", res | |||
| if "tenant_no_auth" in name: | |||
| assert res["message"] == f"Unauthorized model: <{embedding_model}>", res | |||
| else: | |||
| assert res["message"] == f"Unsupported model: <{embedding_model}>", res | |||
| @pytest.mark.parametrize( | |||
| "name, embedding_model", | |||
| [ | |||
| ("builtin_missing_at", "BAAI/bge-large-zh-v1.5"), | |||
| ("tenant_missing_at", "embedding-3ZHIPU-AI"), | |||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | |||
| ("missing_model_name", "@BAAI"), | |||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | |||
| ("whitespace_only_model_name", " @BAAI"), | |||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | |||
| ], | |||
| ids=["builtin_missing_at", "tenant_missing_at"], | |||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||
| ) | |||
| def test_embedding_model_missing_at(self, get_http_api_auth, name, embedding_model): | |||
| def test_embedding_model_format(self, get_http_api_auth, name, embedding_model): | |||
| payload = {"name": name, "embedding_model": embedding_model} | |||
| res = create_dataset(get_http_api_auth, payload) | |||
| assert res["code"] == 101, res | |||
| assert "Embedding model must be xxx@yyy" in res["message"], res | |||
| if name == "missing_at": | |||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | |||
| else: | |||
| assert "Both model_name and provider must be non-empty strings" in res["message"], res | |||
| @pytest.mark.parametrize( | |||
| "name, permission", | |||
| @@ -485,7 +514,7 @@ class TestDatasetCreation: | |||
| ("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), | |||
| ("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"), | |||
| ("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"), | |||
| ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config have at most 65535 characters"), | |||
| ("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), | |||
| ], | |||
| ids=[ | |||
| "auto_keywords_min_limit", | |||