### What problem does this PR solve?

- Remove `pagerank` from `CreateDatasetReq` and add it to `UpdateDatasetReq`
- Add pagerank update logic in the dataset update endpoint
- Update the API documentation to reflect the changes
- Modify the related test cases and SDK references

#8208

This change makes `pagerank` a mutable property that can only be set after dataset creation, and only when Elasticsearch is the doc engine.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
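For context, a minimal sketch of the intended calling pattern after this change, using the Python SDK touched in the hunks below (the dataset name, host, and API key are placeholders):

```python
from ragflow_sdk import RAGFlow

client = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")

# pagerank is no longer accepted at creation time...
dataset = client.create_dataset(name="kb_demo")

# ...it is set afterwards via update, and only when DOC_ENGINE is elasticsearch.
dataset.update({"pagerank": 50})
```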
```diff
 import logging
+import os

 from flask import request
 from peewee import OperationalError

+from api import settings
 from api.db import FileSource, StatusEnum
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
@@ ... @@
     validate_and_parse_json_request,
     validate_and_parse_request_args,
 )
+from rag.nlp import search
+from rag.settings import PAGERANK_FLD
```
```diff
 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ ... @@
                 "picture", "presentation", "qa", "table", "tag"
               ]
               description: Chunking method.
-            pagerank:
-              type: integer
-              description: Set page rank.
             parser_config:
               type: object
               description: Parser configuration.
```
```diff
     if not ok:
         return err

+    if "pagerank" in req and req["pagerank"] != kb.pagerank:
+        if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
+            return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
+
+        if req["pagerank"] > 0:
+            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
+        else:
+            # Elasticsearch requires PAGERANK_FLD be non-zero!
+            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
+
     if not KnowledgebaseService.update_by_id(kb.id, req):
         return get_error_data_result(message="Update dataset error.(Database error)")
```
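The added branch writes the new rank into the dataset's chunk index when it is positive, and removes the field entirely when the rank is reset to 0, since (per the inline comment) Elasticsearch requires the field to be non-zero. A hedged sketch of driving both paths over HTTP, assuming RAGFlow's documented `PUT /api/v1/datasets/{dataset_id}` endpoint (host, dataset id, and key are placeholders):

```python
import requests

BASE_URL = "http://localhost:9380"  # placeholder host
HEADERS = {"Authorization": "Bearer <YOUR_API_KEY>"}

# Set a positive rank: the endpoint pushes PAGERANK_FLD into the chunk index.
resp = requests.put(f"{BASE_URL}/api/v1/datasets/<dataset_id>", headers=HEADERS, json={"pagerank": 50})
print(resp.json())  # {"code": 0, ...} on Elasticsearch; code 101 with the error message on Infinity

# Reset to 0: the endpoint removes PAGERANK_FLD from the index instead of writing a zero.
resp = requests.put(f"{BASE_URL}/api/v1/datasets/<dataset_id>", headers=HEADERS, json={"pagerank": 0})
print(resp.json())
```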
```diff
     embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
     permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
     chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
-    pagerank: int = Field(default=0, ge=0, le=100)
     parser_config: ParserConfig | None = Field(default=None)

     @field_validator("avatar")
@@ ... @@
 class UpdateDatasetReq(CreateDatasetReq):
     dataset_id: str = Field(...)
     name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
+    pagerank: int = Field(default=0, ge=0, le=100)

     @field_validator("dataset_id", mode="before")
     @classmethod
```
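A self-contained sketch of how the `ge`/`le` bounds on the moved field produce the exact messages the test suites below assert (a plain Pydantic v2 model, not the project's real class hierarchy):

```python
from pydantic import BaseModel, Field, ValidationError

class PagerankSketch(BaseModel):
    # Same constraints as the field moved into UpdateDatasetReq.
    pagerank: int = Field(default=0, ge=0, le=100)

for value in (-1, 101, None):
    try:
        PagerankSketch(pagerank=value)
    except ValidationError as e:
        print(e.errors()[0]["msg"])
# -> "Input should be greater than or equal to 0"
# -> "Input should be less than or equal to 100"
# -> "Input should be a valid integer"
```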
| - `"embedding_model"`: `string` | - `"embedding_model"`: `string` | ||||
| - `"permission"`: `string` | - `"permission"`: `string` | ||||
| - `"chunk_method"`: `string` | - `"chunk_method"`: `string` | ||||
| - `"pagerank"`: `int` | |||||
| - `"parser_config"`: `object` | - `"parser_config"`: `object` | ||||
| ##### Request example | ##### Request example | ||||
| - `"me"`: (Default) Only you can manage the dataset. | - `"me"`: (Default) Only you can manage the dataset. | ||||
| - `"team"`: All team members can manage the dataset. | - `"team"`: All team members can manage the dataset. | ||||
| - `"pagerank"`: (*Body parameter*), `int` | |||||
| refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) | |||||
| - Default: `0` | |||||
| - Minimum: `0` | |||||
| - Maximum: `100` | |||||
| - `"chunk_method"`: (*Body parameter*), `enum<string>` | - `"chunk_method"`: (*Body parameter*), `enum<string>` | ||||
| The chunking method of the dataset to create. Available options: | The chunking method of the dataset to create. Available options: | ||||
| - `"naive"`: General (default) | - `"naive"`: General (default) |
````diff
     embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
     permission: str = "me",
     chunk_method: str = "naive",
-    pagerank: int = 0,
     parser_config: DataSet.ParserConfig = None
 ) -> DataSet
 ```
@@ ... @@
   - `"one"`: One
   - `"email"`: Email

-##### pagerank, `int`
-
-The pagerank of the dataset to create. Defaults to `0`.

 ##### parser_config

 The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:
````
```diff
         embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
         permission: str = "me",
         chunk_method: str = "naive",
-        pagerank: int = 0,
         parser_config: Optional[DataSet.ParserConfig] = None,
     ) -> DataSet:
         payload = {
@@ ... @@
             "embedding_model": embedding_model,
             "permission": permission,
             "chunk_method": chunk_method,
-            "pagerank": pagerank,
         }
         if parser_config is not None:
             payload["parser_config"] = parser_config.to_json()
```
| assert res["code"] == 101, res | assert res["code"] == 101, res | ||||
| assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res | ||||
| @pytest.mark.p2 | |||||
| @pytest.mark.parametrize( | |||||
| "name, pagerank", | |||||
| [ | |||||
| ("pagerank_min", 0), | |||||
| ("pagerank_mid", 50), | |||||
| ("pagerank_max", 100), | |||||
| ], | |||||
| ids=["min", "mid", "max"], | |||||
| ) | |||||
| def test_pagerank(self, HttpApiAuth, name, pagerank): | |||||
| payload = {"name": name, "pagerank": pagerank} | |||||
| res = create_dataset(HttpApiAuth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["pagerank"] == pagerank, res | |||||
| @pytest.mark.p3 | |||||
| @pytest.mark.parametrize( | |||||
| "name, pagerank, expected_message", | |||||
| [ | |||||
| ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), | |||||
| ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), | |||||
| ], | |||||
| ids=["min_limit", "max_limit"], | |||||
| ) | |||||
| def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message): | |||||
| payload = {"name": name, "pagerank": pagerank} | |||||
| res = create_dataset(HttpApiAuth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert expected_message in res["message"], res | |||||
| @pytest.mark.p3 | |||||
| def test_pagerank_unset(self, HttpApiAuth): | |||||
| payload = {"name": "pagerank_unset"} | |||||
| res = create_dataset(HttpApiAuth, payload) | |||||
| assert res["code"] == 0, res | |||||
| assert res["data"]["pagerank"] == 0, res | |||||
| @pytest.mark.p3 | |||||
| def test_pagerank_none(self, HttpApiAuth): | |||||
| payload = {"name": "pagerank_unset", "pagerank": None} | |||||
| res = create_dataset(HttpApiAuth, payload) | |||||
| assert res["code"] == 101, res | |||||
| assert "Input should be a valid integer" in res["message"], res | |||||
| @pytest.mark.p1 | @pytest.mark.p1 | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "name, parser_config", | "name, parser_config", | ||||
| {"name": "chunk_count", "chunk_count": 1}, | {"name": "chunk_count", "chunk_count": 1}, | ||||
| {"name": "token_num", "token_num": 1}, | {"name": "token_num", "token_num": 1}, | ||||
| {"name": "status", "status": "1"}, | {"name": "status", "status": "1"}, | ||||
| {"name": "pagerank", "pagerank": 50}, | |||||
| {"name": "unknown_field", "unknown_field": "unknown_field"}, | {"name": "unknown_field", "unknown_field": "unknown_field"}, | ||||
| ], | ], | ||||
| ) | ) |
```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
-from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
+from common import list_datasets, update_dataset
+from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
 from hypothesis import HealthCheck, example, given, settings
 from libs.auth import RAGFlowHttpApiAuth
 from utils import encode_avatar
@@ ... @@
     @pytest.mark.p3
     def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
-        dataset_ids = add_datasets_func[0]
+        dataset_id = add_datasets_func[0]
         name = "dataset_1"
         payload = {"name": name}
-        res = update_dataset(HttpApiAuth, dataset_ids, payload)
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == f"Dataset name '{name}' already exists", res
@@ ... @@
         assert res["code"] == 101, res
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):
@@ ... @@
         assert res["code"] == 0, res
         assert res["data"][0]["pagerank"] == pagerank

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0, res
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 50, res
+
+        payload = {"pagerank": 0}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 0, res
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 101, res
+        assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",
```
```diff
             client.create_dataset(**payload)
         assert "not instance of" in str(excinfo.value), str(excinfo.value)

-    @pytest.mark.p2
-    @pytest.mark.parametrize(
-        "name, pagerank",
-        [
-            ("pagerank_min", 0),
-            ("pagerank_mid", 50),
-            ("pagerank_max", 100),
-        ],
-        ids=["min", "mid", "max"],
-    )
-    def test_pagerank(self, client, name, pagerank):
-        payload = {"name": name, "pagerank": pagerank}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == pagerank, str(dataset)
-
-    @pytest.mark.p3
-    @pytest.mark.parametrize(
-        "name, pagerank, expected_message",
-        [
-            ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
-            ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
-        ],
-        ids=["min_limit", "max_limit"],
-    )
-    def test_pagerank_invalid(self, client, name, pagerank, expected_message):
-        payload = {"name": name, "pagerank": pagerank}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert expected_message in str(excinfo.value), str(excinfo.value)
-
-    @pytest.mark.p3
-    def test_pagerank_unset(self, client):
-        payload = {"name": "pagerank_unset"}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == 0, str(dataset)
-
-    @pytest.mark.p3
-    def test_pagerank_none(self, client):
-        payload = {"name": "pagerank_unset", "pagerank": None}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert "not instance of" in str(excinfo.value), str(excinfo.value)
-
     @pytest.mark.p1
     @pytest.mark.parametrize(
         "name, parser_config",
@@ ... @@
             {"name": "chunk_count", "chunk_count": 1},
             {"name": "token_num", "token_num": 1},
             {"name": "status", "status": "1"},
+            {"name": "pagerank", "pagerank": 50},
             {"name": "unknown_field", "unknown_field": "unknown_field"},
         ],
     )
```
```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from operator import attrgetter
@@ ... @@
             dataset.update({"chunk_method": None})
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, client, add_dataset_func, pagerank):
@@ ... @@
         retrieved_dataset = client.get_dataset(name=dataset.name)
         assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        dataset.update({"pagerank": 50})
+        assert dataset.pagerank == 50, str(dataset)
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)
+
+        dataset.update({"pagerank": 0})
+        assert dataset.pagerank == 0, str(dataset)
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        with pytest.raises(Exception) as excinfo:
+            dataset.update({"pagerank": 50})
+        assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",
```