### What problem does this PR solve? Previous: - Defaulted to hardcoded model 'BAAI/bge-large-zh-v1.5@BAAI' - Did not respect user-configured default embedding_model Now: - Correctly prioritizes user-configured default embedding_model Other: - Make embedding_model optional in CreateDatasetReq with proper None handling - Add default embedding model fallback in dataset update when empty - Enhance validation utils to handle None values and string normalization - Update SDK default embedding model to None to match API changes - Adjust related test cases to reflect new validation rules ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)tags/v0.20.0
| return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") | return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") | ||||
| if "embd_id" in req: | if "embd_id" in req: | ||||
| if not req["embd_id"]: | |||||
| req["embd_id"] = kb.embd_id | |||||
| if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: | if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: | ||||
| return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") | return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") | ||||
| ok, err = verify_embedding_availability(req["embd_id"], tenant_id) | ok, err = verify_embedding_availability(req["embd_id"], tenant_id) |
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] | name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] | ||||
| avatar: str | None = Field(default=None, max_length=65535) | avatar: str | None = Field(default=None, max_length=65535) | ||||
| description: str | None = Field(default=None, max_length=65535) | description: str | None = Field(default=None, max_length=65535) | ||||
| embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] | |||||
| embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id") | |||||
| permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) | permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) | ||||
| chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") | chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") | ||||
| parser_config: ParserConfig | None = Field(default=None) | parser_config: ParserConfig | None = Field(default=None) | ||||
| else: | else: | ||||
| raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>") | raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>") | ||||
| @field_validator("embedding_model", mode="before") | |||||
| @classmethod | |||||
| def normalize_embedding_model(cls, v: Any) -> Any: | |||||
| if isinstance(v, str): | |||||
| return v.strip() | |||||
| return v | |||||
| @field_validator("embedding_model", mode="after") | @field_validator("embedding_model", mode="after") | ||||
| @classmethod | @classmethod | ||||
| def validate_embedding_model(cls, v: str) -> str: | |||||
| def validate_embedding_model(cls, v: str | None) -> str | None: | |||||
| """ | """ | ||||
| Validates embedding model identifier format compliance. | Validates embedding model identifier format compliance. | ||||
| Invalid: "@openai" (empty model_name) | Invalid: "@openai" (empty model_name) | ||||
| Invalid: "text-embedding-3-large@" (empty provider) | Invalid: "text-embedding-3-large@" (empty provider) | ||||
| """ | """ | ||||
| if "@" not in v: | |||||
| raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format") | |||||
| if isinstance(v, str): | |||||
| if "@" not in v: | |||||
| raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format") | |||||
| components = v.split("@", 1) | |||||
| if len(components) != 2 or not all(components): | |||||
| raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings") | |||||
| components = v.split("@", 1) | |||||
| if len(components) != 2 or not all(components): | |||||
| raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings") | |||||
| model_name, provider = components | |||||
| if not model_name.strip() or not provider.strip(): | |||||
| raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings") | |||||
| model_name, provider = components | |||||
| if not model_name.strip() or not provider.strip(): | |||||
| raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings") | |||||
| return v | return v | ||||
| @field_validator("permission", mode="before") | @field_validator("permission", mode="before") |
| name: str, | name: str, | ||||
| avatar: Optional[str] = None, | avatar: Optional[str] = None, | ||||
| description: Optional[str] = None, | description: Optional[str] = None, | ||||
| embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", | |||||
| embedding_model: Optional[str] = None, | |||||
| permission: str = "me", | permission: str = "me", | ||||
| chunk_method: str = "naive", | chunk_method: str = "naive", | ||||
| parser_config: Optional[DataSet.ParserConfig] = None, | parser_config: Optional[DataSet.ParserConfig] = None, |
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, embedding_model", | "name, embedding_model", | ||||
| [ | [ | ||||
| ("empty", ""), | |||||
| ("space", " "), | |||||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ||||
| ("missing_model_name", "@BAAI"), | ("missing_model_name", "@BAAI"), | ||||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ||||
| ("whitespace_only_model_name", " @BAAI"), | ("whitespace_only_model_name", " @BAAI"), | ||||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ||||
| ], | ], | ||||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ) | ) | ||||
| def test_embedding_model_format(self, HttpApiAuth, name, embedding_model): | def test_embedding_model_format(self, HttpApiAuth, name, embedding_model): | ||||
| payload = {"name": name, "embedding_model": embedding_model} | payload = {"name": name, "embedding_model": embedding_model} | ||||
| res = create_dataset(HttpApiAuth, payload) | res = create_dataset(HttpApiAuth, payload) | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| if name == "missing_at": | |||||
| if name in ["empty", "space", "missing_at"]: | |||||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | ||||
| else: | else: | ||||
| assert "Both model_name and provider must be non-empty strings" in res["message"], res | assert "Both model_name and provider must be non-empty strings" in res["message"], res | ||||
| def test_embedding_model_none(self, HttpApiAuth): | def test_embedding_model_none(self, HttpApiAuth): | ||||
| payload = {"name": "embedding_model_none", "embedding_model": None} | payload = {"name": "embedding_model_none", "embedding_model": None} | ||||
| res = create_dataset(HttpApiAuth, payload) | res = create_dataset(HttpApiAuth, payload) | ||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid string" in res["message"], res | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( |
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, embedding_model", | "name, embedding_model", | ||||
| [ | [ | ||||
| ("empty", ""), | |||||
| ("space", " "), | |||||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ||||
| ("missing_model_name", "@BAAI"), | ("missing_model_name", "@BAAI"), | ||||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ||||
| ("whitespace_only_model_name", " @BAAI"), | ("whitespace_only_model_name", " @BAAI"), | ||||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ||||
| ], | ], | ||||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ) | ) | ||||
| def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model): | def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model): | ||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"name": name, "embedding_model": embedding_model} | payload = {"name": name, "embedding_model": embedding_model} | ||||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | res = update_dataset(HttpApiAuth, dataset_id, payload) | ||||
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| if name == "missing_at": | |||||
| if name in ["empty", "space", "missing_at"]: | |||||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res | ||||
| else: | else: | ||||
| assert "Both model_name and provider must be non-empty strings" in res["message"], res | assert "Both model_name and provider must be non-empty strings" in res["message"], res | ||||
| dataset_id = add_dataset_func | dataset_id = add_dataset_func | ||||
| payload = {"embedding_model": None} | payload = {"embedding_model": None} | ||||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | res = update_dataset(HttpApiAuth, dataset_id, payload) | ||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid string" in res["message"], res | |||||
| assert res["code"] == 0, res | |||||
| res = list_datasets(HttpApiAuth) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"][0]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( |
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, embedding_model", | "name, embedding_model", | ||||
| [ | [ | ||||
| ("empty", ""), | |||||
| ("space", " "), | |||||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ||||
| ("missing_model_name", "@BAAI"), | ("missing_model_name", "@BAAI"), | ||||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ||||
| ("whitespace_only_model_name", " @BAAI"), | ("whitespace_only_model_name", " @BAAI"), | ||||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ||||
| ], | ], | ||||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ) | ) | ||||
| def test_embedding_model_format(self, client, name, embedding_model): | def test_embedding_model_format(self, client, name, embedding_model): | ||||
| payload = {"name": name, "embedding_model": embedding_model} | payload = {"name": name, "embedding_model": embedding_model} | ||||
| with pytest.raises(Exception) as excinfo: | with pytest.raises(Exception) as excinfo: | ||||
| client.create_dataset(**payload) | client.create_dataset(**payload) | ||||
| if name == "missing_at": | |||||
| if name in ["empty", "space", "missing_at"]: | |||||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in str(excinfo.value), str(excinfo.value) | assert "Embedding model identifier must follow <model_name>@<provider> format" in str(excinfo.value), str(excinfo.value) | ||||
| else: | else: | ||||
| assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value) | assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value) | ||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| def test_embedding_model_none(self, client): | def test_embedding_model_none(self, client): | ||||
| payload = {"name": "embedding_model_none", "embedding_model": None} | payload = {"name": "embedding_model_none", "embedding_model": None} | ||||
| with pytest.raises(Exception) as excinfo: | |||||
| client.create_dataset(**payload) | |||||
| assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) | |||||
| dataset = client.create_dataset(**payload) | |||||
| assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset) | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( |
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, embedding_model", | "name, embedding_model", | ||||
| [ | [ | ||||
| ("empty", ""), | |||||
| ("space", " "), | |||||
| ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), | ||||
| ("missing_model_name", "@BAAI"), | ("missing_model_name", "@BAAI"), | ||||
| ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ("missing_provider", "BAAI/bge-large-zh-v1.5@"), | ||||
| ("whitespace_only_model_name", " @BAAI"), | ("whitespace_only_model_name", " @BAAI"), | ||||
| ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), | ||||
| ], | ], | ||||
| ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], | |||||
| ) | ) | ||||
| def test_embedding_model_format(self, add_dataset_func, name, embedding_model): | def test_embedding_model_format(self, add_dataset_func, name, embedding_model): | ||||
| dataset = add_dataset_func | dataset = add_dataset_func | ||||
| with pytest.raises(Exception) as excinfo: | with pytest.raises(Exception) as excinfo: | ||||
| dataset.update({"name": name, "embedding_model": embedding_model}) | dataset.update({"name": name, "embedding_model": embedding_model}) | ||||
| error_msg = str(excinfo.value) | error_msg = str(excinfo.value) | ||||
| if name == "missing_at": | |||||
| if name in ["empty", "space", "missing_at"]: | |||||
| assert "Embedding model identifier must follow <model_name>@<provider> format" in error_msg, error_msg | assert "Embedding model identifier must follow <model_name>@<provider> format" in error_msg, error_msg | ||||
| else: | else: | ||||
| assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg | assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg | ||||
| @pytest.mark.p2 | @pytest.mark.p2 | ||||
| def test_embedding_model_none(self, add_dataset_func): | |||||
| def test_embedding_model_none(self, client, add_dataset_func): | |||||
| dataset = add_dataset_func | dataset = add_dataset_func | ||||
| with pytest.raises(Exception) as excinfo: | |||||
| dataset.update({"embedding_model": None}) | |||||
| assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) | |||||
| dataset.update({"embedding_model": None}) | |||||
| assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset) | |||||
| retrieved_dataset = client.get_dataset(name=dataset.name) | |||||
| assert retrieved_dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(retrieved_dataset) | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( |