Ver código fonte

Fix: Move pagerank field from create to update dataset API (#8217)

### What problem does this PR solve?

- Remove pagerank from CreateDatasetReq and add to UpdateDatasetReq
- Add pagerank update logic in dataset update endpoint
- Update API documentation to reflect changes
- Modify related test cases and SDK references

#8208

This change makes pagerank a mutable property that can only be set after
dataset creation, and only when using elasticsearch as the doc engine.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.19.1
Liu An 4 meses atrás
pai
commit
7fbbc9650d
Nenhuma conta vinculada ao e-mail do autor do commit

+ 14
- 3
api/apps/sdk/dataset.py Ver arquivo

@@ -16,10 +16,12 @@


import logging
import os

from flask import request
from peewee import OperationalError

from api import settings
from api.db import FileSource, StatusEnum
from api.db.db_models import File
from api.db.services.document_service import DocumentService
@@ -48,6 +50,8 @@ from api.utils.validation_utils import (
validate_and_parse_json_request,
validate_and_parse_request_args,
)
from rag.nlp import search
from rag.settings import PAGERANK_FLD


@manager.route("/datasets", methods=["POST"]) # noqa: F821
@@ -97,9 +101,6 @@ def create(tenant_id):
"picture", "presentation", "qa", "table", "tag"
]
description: Chunking method.
pagerank:
type: integer
description: Set page rank.
parser_config:
type: object
description: Parser configuration.
@@ -352,6 +353,16 @@ def update(tenant_id, dataset_id):
if not ok:
return err

if "pagerank" in req and req["pagerank"] != kb.pagerank:
if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")

if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)

if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")


+ 1
- 1
api/utils/validation_utils.py Ver arquivo

@@ -383,7 +383,6 @@ class CreateDatasetReq(Base):
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None)

@field_validator("avatar")
@@ -539,6 +538,7 @@ class CreateDatasetReq(Base):
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: str = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
pagerank: int = Field(default=0, ge=0, le=100)

@field_validator("dataset_id", mode="before")
@classmethod

+ 0
- 7
docs/references/http_api_reference.md Ver arquivo

@@ -343,7 +343,6 @@ Creates a dataset.
- `"embedding_model"`: `string`
- `"permission"`: `string`
- `"chunk_method"`: `string`
- `"pagerank"`: `int`
- `"parser_config"`: `object`

##### Request example
@@ -384,12 +383,6 @@ curl --request POST \
- `"me"`: (Default) Only you can manage the dataset.
- `"team"`: All team members can manage the dataset.

- `"pagerank"`: (*Body parameter*), `int`
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
- Default: `0`
- Minimum: `0`
- Maximum: `100`

- `"chunk_method"`: (*Body parameter*), `enum<string>`
The chunking method of the dataset to create. Available options:
- `"naive"`: General (default)

+ 0
- 5
docs/references/python_api_reference.md Ver arquivo

@@ -100,7 +100,6 @@ RAGFlow.create_dataset(
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
permission: str = "me",
chunk_method: str = "naive",
pagerank: int = 0,
parser_config: DataSet.ParserConfig = None
) -> DataSet
```
@@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options:
- `"one"`: One
- `"email"`: Email

##### pagerank, `int`

The pagerank of the dataset to create. Defaults to `0`.

##### parser_config

The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:

+ 0
- 2
sdk/python/ragflow_sdk/ragflow.py Ver arquivo

@@ -56,7 +56,6 @@ class RAGFlow:
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
permission: str = "me",
chunk_method: str = "naive",
pagerank: int = 0,
parser_config: Optional[DataSet.ParserConfig] = None,
) -> DataSet:
payload = {
@@ -66,7 +65,6 @@ class RAGFlow:
"embedding_model": embedding_model,
"permission": permission,
"chunk_method": chunk_method,
"pagerank": pagerank,
}
if parser_config is not None:
payload["parser_config"] = parser_config.to_json()

+ 1
- 45
test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py Ver arquivo

@@ -394,51 +394,6 @@ class TestDatasetCreate:
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res

@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, HttpApiAuth, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == pagerank, res

@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res

@pytest.mark.p3
def test_pagerank_unset(self, HttpApiAuth):
payload = {"name": "pagerank_unset"}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == 0, res

@pytest.mark.p3
def test_pagerank_none(self, HttpApiAuth):
payload = {"name": "pagerank_unset", "pagerank": None}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res

@pytest.mark.p1
@pytest.mark.parametrize(
"name, parser_config",
@@ -730,6 +685,7 @@ class TestDatasetCreate:
{"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"},
{"name": "pagerank", "pagerank": 50},
{"name": "unknown_field", "unknown_field": "unknown_field"},
],
)

+ 35
- 3
test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py Ver arquivo

@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
from common import list_datasets, update_dataset
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
from hypothesis import HealthCheck, example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
@@ -155,10 +157,10 @@ class TestDatasetUpdate:

@pytest.mark.p3
def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
dataset_ids = add_datasets_func[0]
dataset_id = add_datasets_func[0]
name = "dataset_1"
payload = {"name": name}
res = update_dataset(HttpApiAuth, dataset_ids, payload)
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res

@@ -425,6 +427,7 @@ class TestDatasetUpdate:
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res

@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):
@@ -437,6 +440,35 @@ class TestDatasetUpdate:
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == pagerank

@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 50}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 0, res

res = list_datasets(HttpApiAuth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == 50, res

payload = {"pagerank": 0}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 0

res = list_datasets(HttpApiAuth, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == 0, res

@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": 50}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 101, res
assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res

@pytest.mark.p2
@pytest.mark.parametrize(
"pagerank, expected_message",

+ 1
- 43
test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py Ver arquivo

@@ -344,49 +344,6 @@ class TestDatasetCreate:
client.create_dataset(**payload)
assert "not instance of" in str(excinfo.value), str(excinfo.value)

@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, client, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
dataset = client.create_dataset(**payload)
assert dataset.pagerank == pagerank, str(dataset)

@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, client, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert expected_message in str(excinfo.value), str(excinfo.value)

@pytest.mark.p3
def test_pagerank_unset(self, client):
payload = {"name": "pagerank_unset"}
dataset = client.create_dataset(**payload)
assert dataset.pagerank == 0, str(dataset)

@pytest.mark.p3
def test_pagerank_none(self, client):
payload = {"name": "pagerank_unset", "pagerank": None}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert "not instance of" in str(excinfo.value), str(excinfo.value)

@pytest.mark.p1
@pytest.mark.parametrize(
"name, parser_config",
@@ -689,6 +646,7 @@ class TestDatasetCreate:
{"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"},
{"name": "pagerank", "pagerank": 50},
{"name": "unknown_field", "unknown_field": "unknown_field"},
],
)

+ 26
- 0
test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py Ver arquivo

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from operator import attrgetter

@@ -324,6 +325,7 @@ class TestDatasetUpdate:
dataset.update({"chunk_method": None})
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)

@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, client, add_dataset_func, pagerank):
@@ -334,6 +336,30 @@ class TestDatasetUpdate:
retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)

@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_set_to_0(self, client, add_dataset_func):
dataset = add_dataset_func
dataset.update({"pagerank": 50})
assert dataset.pagerank == 50, str(dataset)

retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)

dataset.update({"pagerank": 0})
assert dataset.pagerank == 0, str(dataset)

retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)

@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
@pytest.mark.p2
def test_pagerank_infinity(self, client, add_dataset_func):
dataset = add_dataset_func
with pytest.raises(Exception) as excinfo:
dataset.update({"pagerank": 50})
assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)

@pytest.mark.p2
@pytest.mark.parametrize(
"pagerank, expected_message",

Carregando…
Cancelar
Salvar