Sfoglia il codice sorgente

Fix: Move pagerank field from create to update dataset API (#8217)

### What problem does this PR solve?

- Remove pagerank from CreateDatasetReq and add to UpdateDatasetReq
- Add pagerank update logic in dataset update endpoint
- Update API documentation to reflect changes
- Modify related test cases and SDK references

#8208

This change makes pagerank a mutable property that can only be set after
dataset creation, and only when Elasticsearch is used as the doc engine.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.19.1
Liu An 4 mesi fa
parent
commit
7fbbc9650d
Nessun account collegato all'indirizzo email del committer

+ 14
- 3
api/apps/sdk/dataset.py Vedi File





import logging import logging
import os


from flask import request from flask import request
from peewee import OperationalError from peewee import OperationalError


from api import settings
from api.db import FileSource, StatusEnum from api.db import FileSource, StatusEnum
from api.db.db_models import File from api.db.db_models import File
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
validate_and_parse_json_request, validate_and_parse_json_request,
validate_and_parse_request_args, validate_and_parse_request_args,
) )
from rag.nlp import search
from rag.settings import PAGERANK_FLD




@manager.route("/datasets", methods=["POST"]) # noqa: F821 @manager.route("/datasets", methods=["POST"]) # noqa: F821
"picture", "presentation", "qa", "table", "tag" "picture", "presentation", "qa", "table", "tag"
] ]
description: Chunking method. description: Chunking method.
pagerank:
type: integer
description: Set page rank.
parser_config: parser_config:
type: object type: object
description: Parser configuration. description: Parser configuration.
if not ok: if not ok:
return err return err


if "pagerank" in req and req["pagerank"] != kb.pagerank:
if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")

if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD to be non-zero, so a pagerank of 0 is applied by removing the field instead.
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)

if not KnowledgebaseService.update_by_id(kb.id, req): if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)") return get_error_data_result(message="Update dataset error.(Database error)")



+ 1
- 1
api/utils/validation_utils.py Vedi File

embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None) parser_config: ParserConfig | None = Field(default=None)


@field_validator("avatar") @field_validator("avatar")
class UpdateDatasetReq(CreateDatasetReq): class UpdateDatasetReq(CreateDatasetReq):
dataset_id: str = Field(...) dataset_id: str = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
pagerank: int = Field(default=0, ge=0, le=100)


@field_validator("dataset_id", mode="before") @field_validator("dataset_id", mode="before")
@classmethod @classmethod

+ 0
- 7
docs/references/http_api_reference.md Vedi File

- `"embedding_model"`: `string` - `"embedding_model"`: `string`
- `"permission"`: `string` - `"permission"`: `string`
- `"chunk_method"`: `string` - `"chunk_method"`: `string`
- `"pagerank"`: `int`
- `"parser_config"`: `object` - `"parser_config"`: `object`


##### Request example ##### Request example
- `"me"`: (Default) Only you can manage the dataset. - `"me"`: (Default) Only you can manage the dataset.
- `"team"`: All team members can manage the dataset. - `"team"`: All team members can manage the dataset.


- `"pagerank"`: (*Body parameter*), `int`
Refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank).
- Default: `0`
- Minimum: `0`
- Maximum: `100`

- `"chunk_method"`: (*Body parameter*), `enum<string>` - `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method of the dataset to create. Available options: The chunking method of the dataset to create. Available options:
- `"naive"`: General (default) - `"naive"`: General (default)

+ 0
- 5
docs/references/python_api_reference.md Vedi File

embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
permission: str = "me", permission: str = "me",
chunk_method: str = "naive", chunk_method: str = "naive",
pagerank: int = 0,
parser_config: DataSet.ParserConfig = None parser_config: DataSet.ParserConfig = None
) -> DataSet ) -> DataSet
``` ```
- `"one"`: One - `"one"`: One
- `"email"`: Email - `"email"`: Email


##### pagerank, `int`

The pagerank of the dataset to create. Defaults to `0`.

##### parser_config ##### parser_config


The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`: The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:

+ 0
- 2
sdk/python/ragflow_sdk/ragflow.py Vedi File

embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
permission: str = "me", permission: str = "me",
chunk_method: str = "naive", chunk_method: str = "naive",
pagerank: int = 0,
parser_config: Optional[DataSet.ParserConfig] = None, parser_config: Optional[DataSet.ParserConfig] = None,
) -> DataSet: ) -> DataSet:
payload = { payload = {
"embedding_model": embedding_model, "embedding_model": embedding_model,
"permission": permission, "permission": permission,
"chunk_method": chunk_method, "chunk_method": chunk_method,
"pagerank": pagerank,
} }
if parser_config is not None: if parser_config is not None:
payload["parser_config"] = parser_config.to_json() payload["parser_config"] = parser_config.to_json()

+ 1
- 45
test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py Vedi File

assert res["code"] == 101, res assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res


@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, HttpApiAuth, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == pagerank, res

@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res

@pytest.mark.p3
def test_pagerank_unset(self, HttpApiAuth):
payload = {"name": "pagerank_unset"}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == 0, res

@pytest.mark.p3
def test_pagerank_none(self, HttpApiAuth):
payload = {"name": "pagerank_unset", "pagerank": None}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res

@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, parser_config", "name, parser_config",
{"name": "chunk_count", "chunk_count": 1}, {"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1}, {"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"}, {"name": "status", "status": "1"},
{"name": "pagerank", "pagerank": 50},
{"name": "unknown_field", "unknown_field": "unknown_field"}, {"name": "unknown_field", "unknown_field": "unknown_field"},
], ],
) )

+ 35
- 3
test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py Vedi File

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import os
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed


import pytest import pytest
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
from common import list_datasets, update_dataset
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
from hypothesis import HealthCheck, example, given, settings from hypothesis import HealthCheck, example, given, settings
from libs.auth import RAGFlowHttpApiAuth from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar from utils import encode_avatar


@pytest.mark.p3 @pytest.mark.p3
def test_name_duplicated(self, HttpApiAuth, add_datasets_func): def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
dataset_ids = add_datasets_func[0]
dataset_id = add_datasets_func[0]
name = "dataset_1" name = "dataset_1"
payload = {"name": name} payload = {"name": name}
res = update_dataset(HttpApiAuth, dataset_ids, payload)
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 102, res assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res assert res["message"] == f"Dataset name '{name}' already exists", res


assert res["code"] == 101, res assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res


@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank): def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == pagerank assert res["data"][0]["pagerank"] == pagerank


@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 50}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 0, res

res = list_datasets(HttpApiAuth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == 50, res

payload = {"pagerank": 0}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 0

res = list_datasets(HttpApiAuth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == 0, res

@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 50}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 101, res
assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res

@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pagerank, expected_message", "pagerank, expected_message",

+ 1
- 43
test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py Vedi File

client.create_dataset(**payload) client.create_dataset(**payload)
assert "not instance of" in str(excinfo.value), str(excinfo.value) assert "not instance of" in str(excinfo.value), str(excinfo.value)


@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, client, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
dataset = client.create_dataset(**payload)
assert dataset.pagerank == pagerank, str(dataset)

@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, client, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert expected_message in str(excinfo.value), str(excinfo.value)

@pytest.mark.p3
def test_pagerank_unset(self, client):
payload = {"name": "pagerank_unset"}
dataset = client.create_dataset(**payload)
assert dataset.pagerank == 0, str(dataset)

@pytest.mark.p3
def test_pagerank_none(self, client):
payload = {"name": "pagerank_unset", "pagerank": None}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert "not instance of" in str(excinfo.value), str(excinfo.value)

@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, parser_config", "name, parser_config",
{"name": "chunk_count", "chunk_count": 1}, {"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1}, {"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"}, {"name": "status", "status": "1"},
{"name": "pagerank", "pagerank": 50},
{"name": "unknown_field", "unknown_field": "unknown_field"}, {"name": "unknown_field", "unknown_field": "unknown_field"},
], ],
) )

+ 26
- 0
test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py Vedi File

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from operator import attrgetter from operator import attrgetter


dataset.update({"chunk_method": None}) dataset.update({"chunk_method": None})
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)


@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, client, add_dataset_func, pagerank): def test_pagerank(self, client, add_dataset_func, pagerank):
retrieved_dataset = client.get_dataset(name=dataset.name) retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset) assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)


@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_set_to_0(self, client, add_dataset_func):
dataset = add_dataset_func
dataset.update({"pagerank": 50})
assert dataset.pagerank == 50, str(dataset)

retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)

dataset.update({"pagerank": 0})
assert dataset.pagerank == 0, str(dataset)

retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)

@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_infinity(self, client, add_dataset_func):
dataset = add_dataset_func
with pytest.raises(Exception) as excinfo:
dataset.update({"pagerank": 50})
assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)

@pytest.mark.p2 @pytest.mark.p2
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pagerank, expected_message", "pagerank, expected_message",

Loading…
Annulla
Salva