### What problem does this PR solve?

- Remove `pagerank` from `CreateDatasetReq` and add it to `UpdateDatasetReq`
- Add `pagerank` update logic in the dataset update endpoint
- Update API documentation to reflect the changes
- Modify related test cases and SDK references

Closes #8208

This change makes `pagerank` a mutable property that can only be set after dataset creation, and only when using Elasticsearch as the doc engine.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Tag: v0.19.1
| @@ -16,10 +16,12 @@ | |||
| import logging | |||
| import os | |||
| from flask import request | |||
| from peewee import OperationalError | |||
| from api import settings | |||
| from api.db import FileSource, StatusEnum | |||
| from api.db.db_models import File | |||
| from api.db.services.document_service import DocumentService | |||
| @@ -48,6 +50,8 @@ from api.utils.validation_utils import ( | |||
| validate_and_parse_json_request, | |||
| validate_and_parse_request_args, | |||
| ) | |||
| from rag.nlp import search | |||
| from rag.settings import PAGERANK_FLD | |||
| @manager.route("/datasets", methods=["POST"]) # noqa: F821 | |||
| @@ -97,9 +101,6 @@ def create(tenant_id): | |||
| "picture", "presentation", "qa", "table", "tag" | |||
| ] | |||
| description: Chunking method. | |||
| pagerank: | |||
| type: integer | |||
| description: Set page rank. | |||
| parser_config: | |||
| type: object | |||
| description: Parser configuration. | |||
| @@ -352,6 +353,16 @@ def update(tenant_id, dataset_id): | |||
| if not ok: | |||
| return err | |||
| if "pagerank" in req and req["pagerank"] != kb.pagerank: | |||
| if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": | |||
| return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") | |||
| if req["pagerank"] > 0: | |||
| settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) | |||
| else: | |||
| # Elasticsearch requires PAGERANK_FLD be non-zero! | |||
| settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) | |||
| if not KnowledgebaseService.update_by_id(kb.id, req): | |||
| return get_error_data_result(message="Update dataset error.(Database error)") | |||
| @@ -383,7 +383,6 @@ class CreateDatasetReq(Base): | |||
| embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] | |||
| permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) | |||
| chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") | |||
| pagerank: int = Field(default=0, ge=0, le=100) | |||
| parser_config: ParserConfig | None = Field(default=None) | |||
| @field_validator("avatar") | |||
| @@ -539,6 +538,7 @@ class CreateDatasetReq(Base): | |||
| class UpdateDatasetReq(CreateDatasetReq): | |||
| dataset_id: str = Field(...) | |||
| name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] | |||
| pagerank: int = Field(default=0, ge=0, le=100) | |||
| @field_validator("dataset_id", mode="before") | |||
| @classmethod | |||
| @@ -343,7 +343,6 @@ Creates a dataset. | |||
| - `"embedding_model"`: `string` | |||
| - `"permission"`: `string` | |||
| - `"chunk_method"`: `string` | |||
| - `"pagerank"`: `int` | |||
| - `"parser_config"`: `object` | |||
| ##### Request example | |||
| @@ -384,12 +383,6 @@ curl --request POST \ | |||
| - `"me"`: (Default) Only you can manage the dataset. | |||
| - `"team"`: All team members can manage the dataset. | |||
| - `"pagerank"`: (*Body parameter*), `int` | |||
| Refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank). | |||
| - Default: `0` | |||
| - Minimum: `0` | |||
| - Maximum: `100` | |||
| - `"chunk_method"`: (*Body parameter*), `enum<string>` | |||
| The chunking method of the dataset to create. Available options: | |||
| - `"naive"`: General (default) | |||
| @@ -100,7 +100,6 @@ RAGFlow.create_dataset( | |||
| embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", | |||
| permission: str = "me", | |||
| chunk_method: str = "naive", | |||
| pagerank: int = 0, | |||
| parser_config: DataSet.ParserConfig = None | |||
| ) -> DataSet | |||
| ``` | |||
| @@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options: | |||
| - `"one"`: One | |||
| - `"email"`: Email | |||
| ##### pagerank, `int` | |||
| The pagerank of the dataset to create. Defaults to `0`. | |||
| ##### parser_config | |||
| The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`: | |||
| @@ -56,7 +56,6 @@ class RAGFlow: | |||
| embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", | |||
| permission: str = "me", | |||
| chunk_method: str = "naive", | |||
| pagerank: int = 0, | |||
| parser_config: Optional[DataSet.ParserConfig] = None, | |||
| ) -> DataSet: | |||
| payload = { | |||
| @@ -66,7 +65,6 @@ class RAGFlow: | |||
| "embedding_model": embedding_model, | |||
| "permission": permission, | |||
| "chunk_method": chunk_method, | |||
| "pagerank": pagerank, | |||
| } | |||
| if parser_config is not None: | |||
| payload["parser_config"] = parser_config.to_json() | |||
| @@ -394,51 +394,6 @@ class TestDatasetCreate: | |||
| assert res["code"] == 101, res | |||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "name, pagerank", | |||
| [ | |||
| ("pagerank_min", 0), | |||
| ("pagerank_mid", 50), | |||
| ("pagerank_max", 100), | |||
| ], | |||
| ids=["min", "mid", "max"], | |||
| ) | |||
| def test_pagerank(self, HttpApiAuth, name, pagerank): | |||
| payload = {"name": name, "pagerank": pagerank} | |||
| res = create_dataset(HttpApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["pagerank"] == pagerank, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "name, pagerank, expected_message", | |||
| [ | |||
| ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), | |||
| ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), | |||
| ], | |||
| ids=["min_limit", "max_limit"], | |||
| ) | |||
| def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message): | |||
| payload = {"name": name, "pagerank": pagerank} | |||
| res = create_dataset(HttpApiAuth, payload) | |||
| assert res["code"] == 101, res | |||
| assert expected_message in res["message"], res | |||
| @pytest.mark.p3 | |||
| def test_pagerank_unset(self, HttpApiAuth): | |||
| payload = {"name": "pagerank_unset"} | |||
| res = create_dataset(HttpApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["pagerank"] == 0, res | |||
| @pytest.mark.p3 | |||
| def test_pagerank_none(self, HttpApiAuth): | |||
| payload = {"name": "pagerank_unset", "pagerank": None} | |||
| res = create_dataset(HttpApiAuth, payload) | |||
| assert res["code"] == 101, res | |||
| assert "Input should be a valid integer" in res["message"], res | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "name, parser_config", | |||
| @@ -730,6 +685,7 @@ class TestDatasetCreate: | |||
| {"name": "chunk_count", "chunk_count": 1}, | |||
| {"name": "token_num", "token_num": 1}, | |||
| {"name": "status", "status": "1"}, | |||
| {"name": "pagerank", "pagerank": 50}, | |||
| {"name": "unknown_field", "unknown_field": "unknown_field"}, | |||
| ], | |||
| ) | |||
| @@ -13,11 +13,13 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset | |||
| from common import list_datasets, update_dataset | |||
| from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN | |||
| from hypothesis import HealthCheck, example, given, settings | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| from utils import encode_avatar | |||
| @@ -155,10 +157,10 @@ class TestDatasetUpdate: | |||
| @pytest.mark.p3 | |||
| def test_name_duplicated(self, HttpApiAuth, add_datasets_func): | |||
| dataset_ids = add_datasets_func[0] | |||
| dataset_id = add_datasets_func[0] | |||
| name = "dataset_1" | |||
| payload = {"name": name} | |||
| res = update_dataset(HttpApiAuth, dataset_ids, payload) | |||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["message"] == f"Dataset name '{name}' already exists", res | |||
| @@ -425,6 +427,7 @@ class TestDatasetUpdate: | |||
| assert res["code"] == 101, res | |||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) | |||
| def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank): | |||
| @@ -437,6 +440,35 @@ class TestDatasetUpdate: | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["pagerank"] == pagerank | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"pagerank": 50} | |||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | |||
| assert res["code"] == 0, res | |||
| res = list_datasets(HttpApiAuth, {"id": dataset_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["pagerank"] == 50, res | |||
| payload = {"pagerank": 0} | |||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | |||
| assert res["code"] == 0 | |||
| res = list_datasets(HttpApiAuth, {"id": dataset_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"][0]["pagerank"] == 0, res | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func): | |||
| dataset_id = add_dataset_func | |||
| payload = {"pagerank": 50} | |||
| res = update_dataset(HttpApiAuth, dataset_id, payload) | |||
| assert res["code"] == 101, res | |||
| assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "pagerank, expected_message", | |||
| @@ -344,49 +344,6 @@ class TestDatasetCreate: | |||
| client.create_dataset(**payload) | |||
| assert "not instance of" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "name, pagerank", | |||
| [ | |||
| ("pagerank_min", 0), | |||
| ("pagerank_mid", 50), | |||
| ("pagerank_max", 100), | |||
| ], | |||
| ids=["min", "mid", "max"], | |||
| ) | |||
| def test_pagerank(self, client, name, pagerank): | |||
| payload = {"name": name, "pagerank": pagerank} | |||
| dataset = client.create_dataset(**payload) | |||
| assert dataset.pagerank == pagerank, str(dataset) | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "name, pagerank, expected_message", | |||
| [ | |||
| ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), | |||
| ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), | |||
| ], | |||
| ids=["min_limit", "max_limit"], | |||
| ) | |||
| def test_pagerank_invalid(self, client, name, pagerank, expected_message): | |||
| payload = {"name": name, "pagerank": pagerank} | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_dataset(**payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.p3 | |||
| def test_pagerank_unset(self, client): | |||
| payload = {"name": "pagerank_unset"} | |||
| dataset = client.create_dataset(**payload) | |||
| assert dataset.pagerank == 0, str(dataset) | |||
| @pytest.mark.p3 | |||
| def test_pagerank_none(self, client): | |||
| payload = {"name": "pagerank_unset", "pagerank": None} | |||
| with pytest.raises(Exception) as excinfo: | |||
| client.create_dataset(**payload) | |||
| assert "not instance of" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "name, parser_config", | |||
| @@ -689,6 +646,7 @@ class TestDatasetCreate: | |||
| {"name": "chunk_count", "chunk_count": 1}, | |||
| {"name": "token_num", "token_num": 1}, | |||
| {"name": "status", "status": "1"}, | |||
| {"name": "pagerank", "pagerank": 50}, | |||
| {"name": "unknown_field", "unknown_field": "unknown_field"}, | |||
| ], | |||
| ) | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from operator import attrgetter | |||
| @@ -324,6 +325,7 @@ class TestDatasetUpdate: | |||
| dataset.update({"chunk_method": None}) | |||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) | |||
| def test_pagerank(self, client, add_dataset_func, pagerank): | |||
| @@ -334,6 +336,30 @@ class TestDatasetUpdate: | |||
| retrieved_dataset = client.get_dataset(name=dataset.name) | |||
| assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset) | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| def test_pagerank_set_to_0(self, client, add_dataset_func): | |||
| dataset = add_dataset_func | |||
| dataset.update({"pagerank": 50}) | |||
| assert dataset.pagerank == 50, str(dataset) | |||
| retrieved_dataset = client.get_dataset(name=dataset.name) | |||
| assert retrieved_dataset.pagerank == 50, str(retrieved_dataset) | |||
| dataset.update({"pagerank": 0}) | |||
| assert dataset.pagerank == 0, str(dataset) | |||
| retrieved_dataset = client.get_dataset(name=dataset.name) | |||
| assert retrieved_dataset.pagerank == 0, str(retrieved_dataset) | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") | |||
| @pytest.mark.p2 | |||
| def test_pagerank_infinity(self, client, add_dataset_func): | |||
| dataset = add_dataset_func | |||
| with pytest.raises(Exception) as excinfo: | |||
| dataset.update({"pagerank": 50}) | |||
| assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "pagerank, expected_message", | |||