Selaa lähdekoodia

Test: Add tests for chunk API endpoints (#8616)

### What problem does this PR solve?

- Add comprehensive test suite for chunk operations including:
  - Test files for create, list, retrieve, update, and delete chunks
  - Authorization tests
  - Batch operations tests
- Update test configurations and common utilities
- Validate `important_kwd` and `question_kwd` fields are lists in
chunk_app.py
- Reorganize imports and clean up duplicate code

### Type of change

- [x] Add test cases
tags/v0.20.0
Liu An 4 kuukautta sitten
vanhempi
commit
0b40eb3e90
No account linked to committer's email address

+ 18
- 14
api/apps/chunk_app.py Näytä tiedosto

@@ -15,27 +15,25 @@
#
import datetime
import json
import re

import xxhash
from flask import request
from flask_login import login_required, current_user
from flask_login import current_user, login_required

from rag.app.qa import rmPrefix, beAdoc
from rag.app.tag import label_question
from rag.nlp import search, rag_tokenizer
from rag.prompts import keyword_extraction, cross_languages
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
from api import settings
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api import settings
from api.utils.api_utils import get_json_result
import xxhash
import re
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace


@manager.route('/list', methods=['POST']) # noqa: F821
@@ -129,9 +127,13 @@ def set():
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req:
if not isinstance(req["important_kwd"], list):
return get_data_error_result(message="`important_kwd` should be a list")
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if "question_kwd" in req:
if not isinstance(req["question_kwd"], list):
return get_data_error_result(message="`question_kwd` should be a list")
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if "tag_kwd" in req:
@@ -235,6 +237,8 @@ def create():
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]

try:
e, doc = DocumentService.get_by_id(req["doc_id"])

+ 0
- 1
test/testcases/test_http_api/conftest.py Näytä tiedosto

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from time import sleep

import pytest

+ 40
- 3
test/testcases/test_web_api/common.py Näytä tiedosto

@@ -24,9 +24,7 @@ HEADERS = {"Content-Type": "application/json"}

KB_APP_URL = f"/{VERSION}/kb"
DOCUMENT_APP_URL = f"/{VERSION}/document"
# FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
# FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
# CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
CHUNK_API_URL = f"/{VERSION}/chunk"
# CHAT_ASSISTANT_API_URL = "/api/v1/chats"
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
@@ -164,3 +162,42 @@ def bulk_upload_documents(auth, kb_id, num, tmp_path):
for document in res["data"]:
document_ids.append(document["id"])
return document_ids


# CHUNK APP
def add_chunk(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/create", headers=headers, auth=auth, json=payload, data=data)
return res.json()


def list_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/list", headers=headers, auth=auth, json=payload)
return res.json()


def get_chunk(auth, params=None, *, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/get", headers=headers, auth=auth, params=params)
return res.json()


def update_chunk(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/set", headers=headers, auth=auth, json=payload)
return res.json()


def delete_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/rm", headers=headers, auth=auth, json=payload)
return res.json()


def retrieval_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/retrieval_test", headers=headers, auth=auth, json=payload)
return res.json()


def batch_add_chunks(auth, doc_id, num):
chunk_ids = []
for i in range(num):
res = add_chunk(auth, {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"})
chunk_ids.append(res["data"]["chunk_id"])
return chunk_ids

+ 50
- 9
test/testcases/test_web_api/conftest.py Näytä tiedosto

@@ -13,18 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from time import sleep

import pytest
from common import (
batch_add_chunks,
batch_create_datasets,
bulk_upload_documents,
delete_chunks,
list_chunks,
list_documents,
list_kbs,
parse_documents,
rm_kb,
)

# from configs import HOST_ADDRESS, VERSION
from libs.auth import RAGFlowWebApiAuth
from pytest import FixtureRequest

# from ragflow_sdk import RAGFlow
from utils import wait_for
from utils.file_utils import (
create_docx_file,
create_eml_file,
@@ -39,6 +44,15 @@ from utils.file_utils import (
)


@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
res = list_documents(_auth, {"kb_id": _kb_id})
for doc in res["data"]["docs"]:
if doc["run"] != "3":
return False
return True


@pytest.fixture
def generate_test_files(request: FixtureRequest, tmp_path):
file_creators = {
@@ -73,11 +87,6 @@ def WebApiAuth(auth):
return RAGFlowWebApiAuth(auth)


# @pytest.fixture(scope="session")
# def client(token: str) -> RAGFlow:
# return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)


@pytest.fixture(scope="function")
def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth):
def cleanup():
@@ -108,3 +117,35 @@ def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) ->

request.addfinalizer(cleanup)
return batch_create_datasets(WebApiAuth, 1)[0]


@pytest.fixture(scope="class")
def add_document(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
# def cleanup():
# res = list_documents(WebApiAuth, {"kb_id": dataset_id})
# for doc in res["data"]["docs"]:
# delete_document(WebApiAuth, {"doc_id": doc["id"]})

# request.addfinalizer(cleanup)

dataset_id = add_dataset
return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 1, ragflow_tmp_dir)[0]


@pytest.fixture(scope="class")
def add_chunks(request, WebApiAuth, add_document):
def cleanup():
res = list_chunks(WebApiAuth, {"doc_id": document_id})
if res["code"] == 0:
chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]]
delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids})

request.addfinalizer(cleanup)

kb_id, document_id = add_document
parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
condition(WebApiAuth, kb_id)
chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
# issues/6487
sleep(1)
return kb_id, document_id, chunk_ids

+ 49
- 0
test/testcases/test_web_api/test_chunk_app/conftest.py Näytä tiedosto

@@ -0,0 +1,49 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from time import sleep

import pytest
from common import batch_add_chunks, delete_chunks, list_chunks, list_documents, parse_documents
from utils import wait_for


@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
res = list_documents(_auth, {"kb_id": _kb_id})
for doc in res["data"]["docs"]:
if doc["run"] != "3":
return False
return True


@pytest.fixture(scope="function")
def add_chunks_func(request, WebApiAuth, add_document):
def cleanup():
res = list_chunks(WebApiAuth, {"doc_id": document_id})
chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]]
delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids})

request.addfinalizer(cleanup)

kb_id, document_id = add_document
parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
condition(WebApiAuth, kb_id)
chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
# issues/6487
sleep(1)
return kb_id, document_id, chunk_ids

+ 223
- 0
test/testcases/test_web_api/test_chunk_app/test_create_chunk.py Näytä tiedosto

@@ -0,0 +1,223 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import add_chunk, delete_document, get_chunk, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


def validate_chunk_details(auth, kb_id, doc_id, payload, res):
chunk_id = res["data"]["chunk_id"]
res = get_chunk(auth, {"chunk_id": chunk_id})
assert res["code"] == 0, res
chunk = res["data"]
assert chunk["doc_id"] == doc_id
assert chunk["kb_id"] == kb_id
assert chunk["content_with_weight"] == payload["content_with_weight"]
if "important_kwd" in payload:
assert chunk["important_kwd"] == payload["important_kwd"]
if "question_kwd" in payload:
expected = [str(q).strip() for q in payload.get("question_kwd", [])]
assert chunk["question_kwd"] == expected


@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = add_chunk(invalid_auth)
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res


class TestAddChunk:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content_with_weight": None}, 100, """TypeError("unsupported operand type(s) for +: 'NoneType' and 'str'")"""),
({"content_with_weight": ""}, 0, ""),
pytest.param(
{"content_with_weight": 1},
100,
"""TypeError("unsupported operand type(s) for +: 'int' and 'str'")""",
marks=pytest.mark.skip,
),
({"content_with_weight": "a"}, 0, ""),
({"content_with_weight": " "}, 0, ""),
({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, WebApiAuth, add_document, payload, expected_code, expected_message):
kb_id, doc_id = add_document
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] == 0:
chunks_count = res["data"]["doc"]["chunk_num"]
else:
chunks_count = 0
res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
assert res["code"] == expected_code, res
if expected_code == 0:
validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content_with_weight": "chunk test", "important_kwd": ["a", "b", "c"]}, 0, ""),
({"content_with_weight": "chunk test", "important_kwd": [""]}, 0, ""),
(
{"content_with_weight": "chunk test", "important_kwd": [1]},
100,
"TypeError('sequence item 0: expected str instance, int found')",
),
({"content_with_weight": "chunk test", "important_kwd": ["a", "a"]}, 0, ""),
({"content_with_weight": "chunk test", "important_kwd": "abc"}, 102, "`important_kwd` is required to be a list"),
({"content_with_weight": "chunk test", "important_kwd": 123}, 102, "`important_kwd` is required to be a list"),
],
)
def test_important_keywords(self, WebApiAuth, add_document, payload, expected_code, expected_message):
kb_id, doc_id = add_document
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] == 0:
chunks_count = res["data"]["doc"]["chunk_num"]
else:
chunks_count = 0
res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
assert res["code"] == expected_code, res
if expected_code == 0:
validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content_with_weight": "chunk test", "question_kwd": ["a", "b", "c"]}, 0, ""),
({"content_with_weight": "chunk test", "question_kwd": [""]}, 0, ""),
({"content_with_weight": "chunk test", "question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"content_with_weight": "chunk test", "question_kwd": ["a", "a"]}, 0, ""),
({"content_with_weight": "chunk test", "question_kwd": "abc"}, 102, "`question_kwd` is required to be a list"),
({"content_with_weight": "chunk test", "question_kwd": 123}, 102, "`question_kwd` is required to be a list"),
],
)
def test_questions(self, WebApiAuth, add_document, payload, expected_code, expected_message):
kb_id, doc_id = add_document
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] == 0:
chunks_count = res["data"]["doc"]["chunk_num"]
else:
chunks_count = 0
res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
assert res["code"] == expected_code, res
if expected_code == 0:
validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p3
@pytest.mark.parametrize(
"doc_id, expected_code, expected_message",
[
("", 102, "Document not found!"),
("invalid_document_id", 102, "Document not found!"),
],
)
def test_invalid_document_id(self, WebApiAuth, add_document, doc_id, expected_code, expected_message):
_, _ = add_document
res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res

@pytest.mark.p3
def test_repeated_add_chunk(self, WebApiAuth, add_document):
payload = {"content_with_weight": "chunk test"}
kb_id, doc_id = add_document
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_num"]

res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
assert res["code"] == 0, res
validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res

res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
assert res["code"] == 0, res
validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + 2, res

@pytest.mark.p2
def test_add_chunk_to_deleted_document(self, WebApiAuth, add_document):
_, doc_id = add_document
delete_document(WebApiAuth, {"doc_id": doc_id})
res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
assert res["code"] == 102, res
assert res["message"] == "Document not found!", res

@pytest.mark.skip(reason="issues/6411")
@pytest.mark.p3
def test_concurrent_add_chunk(self, WebApiAuth, add_document):
count = 50
_, doc_id = add_document
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] == 0:
chunks_count = res["data"]["doc"]["chunk_num"]
else:
chunks_count = 0

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
add_chunk,
WebApiAuth,
{"doc_id": doc_id, "content_with_weight": f"chunk test {i}"},
)
for i in range(count)
]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert res["data"]["doc"]["chunk_num"] == chunks_count + count

+ 145
- 0
test/testcases/test_web_api/test_chunk_app/test_list_chunks.py Näytä tiedosto

@@ -0,0 +1,145 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import batch_add_chunks, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = list_chunks(invalid_auth, {"doc_id": "document_id"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res


class TestChunksList:
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
({"page": 2, "size": 2}, 0, 2, ""),
({"page": 3, "size": 2}, 0, 1, ""),
({"page": "3", "size": 2}, 0, 1, ""),
pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
_, doc_id, _ = add_chunks
payload = {"doc_id": doc_id}
if params:
payload.update(params)
res = list_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")"""),
pytest.param({"size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")),
pytest.param({"size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")),
({"size": 1}, 0, 1, ""),
({"size": 6}, 0, 5, ""),
({"size": "1"}, 0, 1, ""),
pytest.param({"size": -1}, 0, 5, "", marks=pytest.mark.skip),
pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page_size(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
_, doc_id, _ = add_chunks
payload = {"doc_id": doc_id}
if params:
payload.update(params)
res = list_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_page_size",
[
({"keywords": None}, 5),
({"keywords": ""}, 5),
({"keywords": "1"}, 1),
pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
({"keywords": "content"}, 1),
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size):
_, doc_id, _ = add_chunks
payload = {"doc_id": doc_id}
if params:
payload.update(params)
res = list_chunks(WebApiAuth, payload)
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == expected_page_size, res

@pytest.mark.p3
def test_invalid_params(self, WebApiAuth, add_chunks):
_, doc_id, _ = add_chunks
payload = {"doc_id": doc_id, "a": "b"}
res = list_chunks(WebApiAuth, payload)
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == 5, res

@pytest.mark.p3
def test_concurrent_list(self, WebApiAuth, add_chunks):
_, doc_id, _ = add_chunks
count = 100
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_chunks, WebApiAuth, {"doc_id": doc_id}) for i in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures)

@pytest.mark.p1
def test_default(self, WebApiAuth, add_document):
_, doc_id = add_document

res = list_chunks(WebApiAuth, {"doc_id": doc_id})
chunks_count = res["data"]["doc"]["chunk_num"]
batch_add_chunks(WebApiAuth, doc_id, 31)
# issues/6487
from time import sleep

sleep(3)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0
assert len(res["data"]["chunks"]) == 30
assert res["data"]["doc"]["chunk_num"] == chunks_count + 31

+ 308
- 0
test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py Näytä tiedosto

@@ -0,0 +1,308 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import retrieval_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res


class TestChunksRetrieval:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"question": "chunk", "kb_id": None}, 0, 4, ""),
({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
],
)
def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
if "kb_id" in payload:
payload["kb_id"] = [dataset_id]
if "doc_ids" in payload:
payload["doc_ids"] = [document_id]
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"page": None, "size": 2},
100,
0,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
marks=pytest.mark.skip,
),
pytest.param(
{"page": 0, "size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
({"page": 3, "size": 2}, 0, 0, ""),
({"page": "3", "size": 2}, 0, 0, ""),
pytest.param(
{"page": -1, "size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param(
{"page": "a", "size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"size": None},
100,
0,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
marks=pytest.mark.skip,
),
# ({"size": 0}, 0, 0, ""),
({"size": 1}, 0, 1, ""),
({"size": 5}, 0, 4, ""),
({"size": "1"}, 0, 1, ""),
# ({"size": -1}, 0, 0, ""),
pytest.param(
{"size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})

res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"vector_similarity_weight": 0}, 0, 4, ""),
({"vector_similarity_weight": 0.5}, 0, 4, ""),
({"vector_similarity_weight": 10}, 0, 4, ""),
pytest.param(
{"vector_similarity_weight": "a"},
100,
0,
"""ValueError("could not convert string to float: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"top_k": 10}, 0, 4, ""),
pytest.param(
{"top_k": 1},
0,
4,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
),
pytest.param(
{"top_k": 1},
0,
1,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
),
pytest.param(
{"top_k": -1},
100,
4,
"must be greater than 0",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
),
pytest.param(
{"top_k": -1},
100,
4,
"3014",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
),
pytest.param(
{"top_k": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: 'a'")""",
marks=pytest.mark.skip,
),
],
)
def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert expected_message in res["message"], res

@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
],
)
def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) > 0, res
else:
assert expected_message in res["message"], res

@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"keyword": True}, 0, 5, ""),
({"keyword": "True"}, 0, 5, ""),
({"keyword": False}, 0, 5, ""),
({"keyword": "False"}, 0, 5, ""),
({"keyword": None}, 0, 5, ""),
],
)
def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk test", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size, res
else:
assert res["message"] == expected_message, res

@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_highlight, expected_message",
[
({"highlight": True}, 0, True, ""),
({"highlight": "True"}, 0, True, ""),
pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
],
)
def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "kb_id": [dataset_id]})
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if expected_highlight:
for chunk in res["data"]["chunks"]:
assert "highlight" in chunk, res
else:
for chunk in res["data"]["chunks"]:
assert "highlight" not in chunk, res

if expected_code != 0:
assert res["message"] == expected_message, res

@pytest.mark.p3
def test_invalid_params(self, WebApiAuth, add_chunks):
dataset_id, _, _ = add_chunks
payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}
res = retrieval_chunks(WebApiAuth, payload)
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == 4, res

@pytest.mark.p3
def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
dataset_id, _, _ = add_chunks
count = 100
payload = {"question": "chunk", "kb_id": [dataset_id]}

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for i in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

+ 161
- 0
test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py Näytä tiedosto

@@ -0,0 +1,161 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import batch_add_chunks, delete_chunks, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = delete_chunks(invalid_auth, {"doc_id": "document_id", "chunk_ids": ["1"]})
assert res["code"] == expected_code
assert res["message"] == expected_message


class TestChunksDeletion:
@pytest.mark.p3
@pytest.mark.parametrize(
"doc_id, expected_code, expected_message",
[
("", 102, "Document not found!"),
("invalid_document_id", 102, "Document not found!"),
],
)
def test_invalid_document_id(self, WebApiAuth, add_chunks_func, doc_id, expected_code, expected_message):
_, _, chunk_ids = add_chunks_func
res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res

@pytest.mark.parametrize(
"payload",
[
pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
],
)
def test_delete_partial_invalid_id(self, WebApiAuth, add_chunks_func, payload):
_, doc_id, chunk_ids = add_chunks_func
if callable(payload):
payload = payload(chunk_ids)
payload["doc_id"] = doc_id
res = delete_chunks(WebApiAuth, payload)
assert res["code"] == 0, res

res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == 0, res
assert res["data"]["total"] == 0, res

@pytest.mark.p3
def test_repeated_deletion(self, WebApiAuth, add_chunks_func):
_, doc_id, chunk_ids = add_chunks_func
payload = {"chunk_ids": chunk_ids, "doc_id": doc_id}
res = delete_chunks(WebApiAuth, payload)
assert res["code"] == 0, res

res = delete_chunks(WebApiAuth, payload)
assert res["code"] == 102, res
assert res["message"] == "Index updating failure", res

@pytest.mark.p3
def test_duplicate_deletion(self, WebApiAuth, add_chunks_func):
_, doc_id, chunk_ids = add_chunks_func
payload = {"chunk_ids": chunk_ids * 2, "doc_id": doc_id}
res = delete_chunks(WebApiAuth, payload)
assert res["code"] == 0, res

res = list_chunks(WebApiAuth, {"doc_id": doc_id})
assert res["code"] == 0, res
assert len(res["data"]["chunks"]) == 0, res
assert res["data"]["total"] == 0, res

@pytest.mark.p3
def test_concurrent_deletion(self, WebApiAuth, add_document):
count = 100
_, doc_id = add_document
chunk_ids = batch_add_chunks(WebApiAuth, doc_id, count)

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_chunks,
WebApiAuth,
{"doc_id": doc_id, "chunk_ids": chunk_ids[i : i + 1]},
)
for i in range(count)
]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

@pytest.mark.p3
def test_delete_1k(self, WebApiAuth, add_document):
chunks_num = 1_000
_, doc_id = add_document
chunk_ids = batch_add_chunks(WebApiAuth, doc_id, chunks_num)

from time import sleep

sleep(1)

res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
assert res["code"] == 0

res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == 0, res
assert res["data"]["total"] == 0, res

@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", 5, marks=pytest.mark.skip),
pytest.param({"chunk_ids": ["invalid_id"]}, 102, "Index updating failure", 4, marks=pytest.mark.p3),
pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")),
pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3),
pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1),
pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3),
],
)
def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining):
_, doc_id, chunk_ids = add_chunks_func
if callable(payload):
payload = payload(chunk_ids)
payload["doc_id"] = doc_id
res = delete_chunks(WebApiAuth, payload)
assert res["code"] == expected_code, res
if res["code"] != 0:
assert res["message"] == expected_message, res

res = list_chunks(WebApiAuth, {"doc_id": doc_id})
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == remaining, res
assert res["data"]["total"] == remaining, res

+ 232
- 0
test/testcases/test_web_api/test_chunk_app/test_update_chunk.py Näytä tiedosto

@@ -0,0 +1,232 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint
from time import sleep

import pytest
from common import delete_document, list_chunks, update_chunk
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 401, "<Unauthorized '401: Unauthorized'>"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = update_chunk(invalid_auth, {"doc_id": "doc_id", "chunk_id": "chunk_id", "content_with_weight": "test"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res


class TestUpdateChunk:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content_with_weight": None}, 100, "TypeError('expected string or bytes-like object')"),
({"content_with_weight": ""}, 0, ""),
({"content_with_weight": 1}, 100, "TypeError('expected string or bytes-like object')"),
({"content_with_weight": "update chunk"}, 0, ""),
({"content_with_weight": " "}, 0, ""),
({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
_, doc_id, chunk_ids = add_chunks
chunk_id = chunk_ids[0]
update_payload = {"doc_id": doc_id, "chunk_id": chunk_id}
if payload:
update_payload.update(payload)
res = update_chunk(WebApiAuth, update_payload)
assert res["code"] == expected_code, res
if expected_code != 0:
assert res["message"] == expected_message, res
else:
sleep(1)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
for chunk in res["data"]["chunks"]:
if chunk["chunk_id"] == chunk_id:
assert chunk["content_with_weight"] == payload["content_with_weight"]

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"important_kwd": ["a", "b", "c"]}, 0, ""),
({"important_kwd": [""]}, 0, ""),
({"important_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"important_kwd": ["a", "a"]}, 0, ""),
({"important_kwd": "abc"}, 102, "`important_kwd` should be a list"),
({"important_kwd": 123}, 102, "`important_kwd` should be a list"),
],
)
def test_important_keywords(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
_, doc_id, chunk_ids = add_chunks
chunk_id = chunk_ids[0]
update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} # Add content_with_weight as it's required
if payload:
update_payload.update(payload)
res = update_chunk(WebApiAuth, update_payload)
assert res["code"] == expected_code, res
if expected_code != 0:
assert res["message"] == expected_message, res
else:
sleep(1)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
for chunk in res["data"]["chunks"]:
if chunk["chunk_id"] == chunk_id:
assert chunk["important_kwd"] == payload["important_kwd"]

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"question_kwd": ["a", "b", "c"]}, 0, ""),
({"question_kwd": [""]}, 0, ""),
({"question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"question_kwd": ["a", "a"]}, 0, ""),
({"question_kwd": "abc"}, 102, "`question_kwd` should be a list"),
({"question_kwd": 123}, 102, "`question_kwd` should be a list"),
],
)
def test_questions(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
_, doc_id, chunk_ids = add_chunks
chunk_id = chunk_ids[0]
update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} # Add content_with_weight as it's required
if payload:
update_payload.update(payload)

res = update_chunk(WebApiAuth, update_payload)
assert res["code"] == expected_code, res
if expected_code != 0:
assert res["message"] == expected_message, res
else:
sleep(1)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
for chunk in res["data"]["chunks"]:
if chunk["chunk_id"] == chunk_id:
assert chunk["question_kwd"] == payload["question_kwd"]

@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"available_int": 1}, 0, ""),
({"available_int": 0}, 0, ""),
],
)
def test_available(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
_, doc_id, chunk_ids = add_chunks
chunk_id = chunk_ids[0]
update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
if payload:
update_payload.update(payload)

res = update_chunk(WebApiAuth, update_payload)
assert res["code"] == expected_code, res
if expected_code != 0:
assert res["message"] == expected_message, res
else:
sleep(1)
res = list_chunks(WebApiAuth, {"doc_id": doc_id})
for chunk in res["data"]["chunks"]:
if chunk["chunk_id"] == chunk_id:
assert chunk["available_int"] == payload["available_int"]

@pytest.mark.p3
@pytest.mark.parametrize(
"doc_id_param, expected_code, expected_message",
[
("", 102, "Tenant not found!"),
("invalid_doc_id", 102, "Tenant not found!"),
],
)
def test_invalid_document_id_for_update(self, WebApiAuth, add_chunks, doc_id_param, expected_code, expected_message):
_, _, chunk_ids = add_chunks
chunk_id = chunk_ids[0]

payload = {"doc_id": doc_id_param, "chunk_id": chunk_id, "content_with_weight": "test content"}
res = update_chunk(WebApiAuth, payload)
assert res["code"] == expected_code
assert expected_message in res["message"]

@pytest.mark.p3
def test_repeated_update_chunk(self, WebApiAuth, add_chunks):
_, doc_id, chunk_ids = add_chunks
payload1 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 1"}
res = update_chunk(WebApiAuth, payload1)
assert res["code"] == 0

payload2 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 2"}
res = update_chunk(WebApiAuth, payload2)
assert res["code"] == 0

@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"unknown_key": "unknown_value"}, 0, ""),
({}, 0, ""),
pytest.param(None, 100, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
],
)
def test_invalid_params(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
_, doc_id, chunk_ids = add_chunks
chunk_id = chunk_ids[0]
update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
if payload is not None:
update_payload.update(payload)

res = update_chunk(WebApiAuth, update_payload)
assert res["code"] == expected_code, res
if expected_code != 0:
assert res["message"] == expected_message, res

@pytest.mark.p3
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
def test_concurrent_update_chunk(self, WebApiAuth, add_chunks):
count = 50
_, doc_id, chunk_ids = add_chunks

with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
update_chunk,
WebApiAuth,
{"doc_id": doc_id, "chunk_id": chunk_ids[randint(0, 3)], "content_with_weight": f"update chunk test {i}"},
)
for i in range(count)
]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)

@pytest.mark.p3
def test_update_chunk_to_deleted_document(self, WebApiAuth, add_chunks):
_, doc_id, chunk_ids = add_chunks
delete_document(WebApiAuth, {"doc_id": doc_id})
payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "test content"}
res = update_chunk(WebApiAuth, payload)
assert res["code"] == 102, res
assert res["message"] == "Tenant not found!", res

Loading…
Peruuta
Tallenna