You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

validation_utils.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
from enum import auto
from typing import Annotated, List, Optional

from pydantic import BaseModel, ConfigDict, Field, StringConstraints, ValidationError, field_validator
from strenum import StrEnum
  20. def format_validation_error_message(e: ValidationError) -> str:
  21. error_messages = []
  22. for error in e.errors():
  23. field = ".".join(map(str, error["loc"]))
  24. msg = error["msg"]
  25. input_val = error["input"]
  26. input_str = str(input_val)
  27. if len(input_str) > 128:
  28. input_str = input_str[:125] + "..."
  29. error_msg = f"Field: <{field}> - Message: <{msg}> - Value: <{input_str}>"
  30. error_messages.append(error_msg)
  31. return "\n".join(error_messages)
  32. class PermissionEnum(StrEnum):
  33. me = auto()
  34. team = auto()
  35. class ChunkMethodnEnum(StrEnum):
  36. naive = auto()
  37. book = auto()
  38. email = auto()
  39. laws = auto()
  40. manual = auto()
  41. one = auto()
  42. paper = auto()
  43. picture = auto()
  44. presentation = auto()
  45. qa = auto()
  46. table = auto()
  47. tag = auto()
  48. class GraphragMethodEnum(StrEnum):
  49. light = auto()
  50. general = auto()
  51. class Base(BaseModel):
  52. class Config:
  53. extra = "forbid"
  54. json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}
  55. class RaptorConfig(Base):
  56. use_raptor: bool = Field(default=False)
  57. prompt: Annotated[
  58. str,
  59. StringConstraints(strip_whitespace=True, min_length=1),
  60. Field(
  61. default="Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize."
  62. ),
  63. ]
  64. max_token: int = Field(default=256, ge=1, le=2048)
  65. threshold: float = Field(default=0.1, ge=0.0, le=1.0)
  66. max_cluster: int = Field(default=64, ge=1, le=1024)
  67. random_seed: int = Field(default=0, ge=0)
  68. class GraphragConfig(Base):
  69. use_graphrag: bool = Field(default=False)
  70. entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
  71. method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
  72. community: bool = Field(default=False)
  73. resolution: bool = Field(default=False)
  74. class ParserConfig(Base):
  75. auto_keywords: int = Field(default=0, ge=0, le=32)
  76. auto_questions: int = Field(default=0, ge=0, le=10)
  77. chunk_token_num: int = Field(default=128, ge=1, le=2048)
  78. delimiter: str = Field(default=r"\n", min_length=1)
  79. graphrag: Optional[GraphragConfig] = None
  80. html4excel: bool = False
  81. layout_recognize: str = "DeepDOC"
  82. raptor: Optional[RaptorConfig] = None
  83. tag_kb_ids: List[str] = Field(default_factory=list)
  84. topn_tags: int = Field(default=1, ge=1, le=10)
  85. filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
  86. task_page_size: Optional[int] = Field(default=None, ge=1)
  87. pages: Optional[List[List[int]]] = None
  88. class CreateDatasetReq(Base):
  89. name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)]
  90. avatar: Optional[str] = Field(default=None, max_length=65535)
  91. description: Optional[str] = Field(default=None, max_length=65535)
  92. embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
  93. permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
  94. chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
  95. pagerank: int = Field(default=0, ge=0, le=100)
  96. parser_config: Optional[ParserConfig] = Field(default=None)
  97. @field_validator("avatar")
  98. @classmethod
  99. def validate_avatar_base64(cls, v: str) -> str:
  100. if v is None:
  101. return v
  102. if "," in v:
  103. prefix, _ = v.split(",", 1)
  104. if not prefix.startswith("data:"):
  105. raise ValueError("Invalid MIME prefix format. Must start with 'data:'")
  106. mime_type = prefix[5:].split(";")[0]
  107. supported_mime_types = ["image/jpeg", "image/png"]
  108. if mime_type not in supported_mime_types:
  109. raise ValueError(f"Unsupported MIME type. Allowed: {supported_mime_types}")
  110. return v
  111. else:
  112. raise ValueError("Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
  113. @field_validator("embedding_model", mode="after")
  114. @classmethod
  115. def validate_embedding_model(cls, v: str) -> str:
  116. if "@" not in v:
  117. raise ValueError("Embedding model must be xxx@yyy")
  118. return v
  119. @field_validator("permission", mode="before")
  120. @classmethod
  121. def permission_auto_lowercase(cls, v: str) -> str:
  122. if isinstance(v, str):
  123. return v.lower()
  124. return v
  125. @field_validator("parser_config", mode="after")
  126. @classmethod
  127. def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]:
  128. if v is not None:
  129. json_str = v.model_dump_json()
  130. if len(json_str) > 65535:
  131. raise ValueError("Parser config have at most 65535 characters")
  132. return v