### What problem does this PR solve?

- Improve concurrent test cases by using `as_completed` for better reliability
- Rename variables for clarity (`chunk_num` -> `count`)
- Add new SDK API test suite for chunk management operations
- Update HTTP API tests with consistent concurrency patterns

### Type of change

- [x] Add test cases
- [x] Refactoring

tags/v0.19.1
| @@ -13,7 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, add_chunk, delete_documents, list_chunks | |||
| @@ -224,7 +224,7 @@ class TestAddChunk: | |||
| @pytest.mark.skip(reason="issues/6411") | |||
| def test_concurrent_add_chunk(self, api_key, add_document): | |||
| chunk_num = 50 | |||
| count = 50 | |||
| dataset_id, document_id = add_document | |||
| res = list_chunks(api_key, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| @@ -240,11 +240,12 @@ class TestAddChunk: | |||
| document_id, | |||
| {"content": f"chunk test {i}"}, | |||
| ) | |||
| for i in range(chunk_num) | |||
| for i in range(count) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| res = list_chunks(api_key, dataset_id, document_id) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| assert res["data"]["doc"]["chunk_count"] == chunks_count + chunk_num | |||
| assert res["data"]["doc"]["chunk_count"] == chunks_count + count | |||
| @@ -13,7 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, batch_add_chunks, delete_chunks, list_chunks | |||
| @@ -121,9 +121,9 @@ class TestChunksDeletion: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_deletion(self, api_key, add_document): | |||
| chunks_num = 100 | |||
| count = 100 | |||
| dataset_id, document_id = add_document | |||
| chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, chunks_num) | |||
| chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, count) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| @@ -134,10 +134,11 @@ class TestChunksDeletion: | |||
| document_id, | |||
| {"chunk_ids": chunk_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(chunks_num) | |||
| for i in range(count) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_delete_1k(self, api_key, add_document): | |||
| @@ -14,7 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks | |||
| @@ -149,12 +149,12 @@ class TestChunksList: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, api_key, add_chunks): | |||
| dataset_id, document_id, _ = add_chunks | |||
| count = 100 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| assert all(len(r["data"]["chunks"]) == 5 for r in responses) | |||
| futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures) | |||
| @pytest.mark.p1 | |||
| def test_default(self, api_key, add_document): | |||
| @@ -14,6 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import ( | |||
| @@ -302,12 +303,12 @@ class TestChunksRetrieval: | |||
| @pytest.mark.p3 | |||
| def test_concurrent_retrieval(self, api_key, add_chunks): | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| dataset_id, _, _ = add_chunks | |||
| count = 100 | |||
| payload = {"question": "chunk", "dataset_ids": [dataset_id]} | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(100)] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @@ -14,7 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from random import randint | |||
| import pytest | |||
| @@ -219,7 +219,7 @@ class TestUpdatedChunk: | |||
| @pytest.mark.p3 | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") | |||
| def test_concurrent_update_chunk(self, api_key, add_chunks): | |||
| chunk_num = 50 | |||
| count = 50 | |||
| dataset_id, document_id, chunk_ids = add_chunks | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| @@ -232,10 +232,11 @@ class TestUpdatedChunk: | |||
| chunk_ids[randint(0, 3)], | |||
| {"content": f"update chunk test {i}"}, | |||
| ) | |||
| for i in range(chunk_num) | |||
| for i in range(count) | |||
| ] | |||
| responses = [f.result() for f in futures] | |||
| assert all(r["code"] == 0 for r in responses) | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_update_chunk_to_deleted_document(self, api_key, add_chunks): | |||
| @@ -85,7 +85,7 @@ class TestCapability: | |||
| futures = [executor.submit(create_dataset, api_key, {"name": f"dataset_{i}"}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.usefixtures("clear_datasets") | |||
| @@ -93,7 +93,7 @@ class TestCapability: | |||
| futures = [executor.submit(delete_datasets, api_key, {"ids": ids[i : i + 1]}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| class TestDatasetsDelete: | |||
| @@ -49,7 +49,7 @@ class TestCapability: | |||
| futures = [executor.submit(list_datasets, api_key) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.usefixtures("add_datasets") | |||
| @@ -95,7 +95,7 @@ class TestCapability: | |||
| futures = [executor.submit(update_dataset, api_key, dataset_id, {"name": f"dataset_{i}"}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| class TestDatasetUpdate: | |||
| @@ -15,7 +15,6 @@ | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documents, list_documents | |||
| from libs.auth import RAGFlowHttpApiAuth | |||
| @@ -165,7 +164,7 @@ def test_concurrent_deletion(api_key, add_dataset, tmp_path): | |||
| ] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| @@ -348,7 +348,7 @@ class TestDocumentsList: | |||
| futures = [executor.submit(list_documents, api_key, dataset_id) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_invalid_params(self, api_key, add_documents): | |||
| @@ -211,7 +211,7 @@ def test_concurrent_parse(api_key, add_dataset_func, tmp_path): | |||
| ] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| condition(api_key, dataset_id, count) | |||
| @@ -213,7 +213,7 @@ class TestDocumentsUpload: | |||
| futures = [executor.submit(upload_documents, api_key, dataset_id, fps[i : i + 1]) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(futures.result()["code"] == 0 for futures in futures) | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| res = list_datasets(api_key, {"id": dataset_id}) | |||
| assert res["data"][0]["document_count"] == count | |||
| @@ -22,11 +22,7 @@ from utils.file_utils import create_txt_file | |||
| # DATASET MANAGEMENT | |||
| def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]: | |||
| datasets = [] | |||
| for i in range(num): | |||
| dataset = client.create_dataset(name=f"dataset_{i}") | |||
| datasets.append(dataset) | |||
| return datasets | |||
| return [client.create_dataset(name=f"dataset_{i}") for i in range(num)] | |||
| # FILE MANAGEMENT WITHIN DATASET | |||
| @@ -39,3 +35,8 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do | |||
| document_infos.append({"display_name": fp.name, "blob": blob}) | |||
| return dataset.upload_documents(document_infos) | |||
| # CHUNK MANAGEMENT WITHIN DATASET | |||
| def batch_add_chunks(document: Document, num: int): | |||
| return [document.add_chunk(content=f"chunk test {i}") for i in range(num)] | |||
| @@ -23,7 +23,7 @@ from common import ( | |||
| ) | |||
| from configs import HOST_ADDRESS, VERSION | |||
| from pytest import FixtureRequest | |||
| from ragflow_sdk import DataSet, RAGFlow | |||
| from ragflow_sdk import Chunk, DataSet, Document, RAGFlow | |||
| from utils import wait_for | |||
| from utils.file_utils import ( | |||
| create_docx_file, | |||
| @@ -41,7 +41,7 @@ from utils.file_utils import ( | |||
| @wait_for(30, 1, "Document parsing timeout") | |||
| def condition(_dataset: DataSet): | |||
| documents = DataSet.list_documents(page_size=1000) | |||
| documents = _dataset.list_documents(page_size=1000) | |||
| for document in documents: | |||
| if document.run != "DONE": | |||
| return False | |||
| @@ -49,7 +49,7 @@ def condition(_dataset: DataSet): | |||
| @pytest.fixture | |||
| def generate_test_files(request, tmp_path): | |||
| def generate_test_files(request: FixtureRequest, tmp_path: Path): | |||
| file_creators = { | |||
| "docx": (tmp_path / "ragflow_test.docx", create_docx_file), | |||
| "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file), | |||
| @@ -72,13 +72,13 @@ def generate_test_files(request, tmp_path): | |||
| @pytest.fixture(scope="class") | |||
| def ragflow_tmp_dir(request, tmp_path_factory) -> Path: | |||
| def ragflow_tmp_dir(request: FixtureRequest, tmp_path_factory: Path) -> Path: | |||
| class_name = request.cls.__name__ | |||
| return tmp_path_factory.mktemp(class_name) | |||
| @pytest.fixture(scope="session") | |||
| def client(token) -> RAGFlow: | |||
| def client(token: str) -> RAGFlow: | |||
| return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) | |||
| @@ -96,9 +96,7 @@ def add_dataset(request: FixtureRequest, client: RAGFlow): | |||
| client.delete_datasets(ids=None) | |||
| request.addfinalizer(cleanup) | |||
| dataset_ids = batch_create_datasets(client, 1) | |||
| return dataset_ids[0] | |||
| return batch_create_datasets(client, 1)[0] | |||
| @pytest.fixture(scope="function") | |||
| @@ -111,12 +109,31 @@ def add_dataset_func(request: FixtureRequest, client: RAGFlow) -> DataSet: | |||
| @pytest.fixture(scope="class") | |||
| def add_document(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir): | |||
| dataset = add_dataset | |||
| documents = bulk_upload_documents(dataset, 1, ragflow_tmp_dir) | |||
| def add_document(add_dataset: DataSet, ragflow_tmp_dir: Path) -> tuple[DataSet, Document]: | |||
| return add_dataset, bulk_upload_documents(add_dataset, 1, ragflow_tmp_dir)[0] | |||
| @pytest.fixture(scope="class") | |||
| def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: | |||
| dataset, document = add_document | |||
| dataset.async_parse_documents([document.id]) | |||
| condition(dataset) | |||
| chunks = [] | |||
| for i in range(4): | |||
| chunk = document.add_chunk(content=f"chunk test {i}") | |||
| chunks.append(chunk) | |||
| # issues/6487 | |||
| from time import sleep | |||
| sleep(1) | |||
| def cleanup(): | |||
| dataset.delete_documents(ids=None) | |||
| try: | |||
| document.delete_chunks(ids=[]) | |||
| except Exception: | |||
| pass | |||
| request.addfinalizer(cleanup) | |||
| return dataset, documents[0] | |||
| return dataset, document, chunks | |||
| @@ -0,0 +1,49 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import pytest | |||
| from pytest import FixtureRequest | |||
| from ragflow_sdk import Chunk, DataSet, Document | |||
| from utils import wait_for | |||
| @wait_for(30, 1, "Document parsing timeout") | |||
| def condition(_dataset: DataSet): | |||
| documents = _dataset.list_documents(page_size=1000) | |||
| for document in documents: | |||
| if document.run != "DONE": | |||
| return False | |||
| return True | |||
| @pytest.fixture(scope="function") | |||
| def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: | |||
| dataset, document = add_document | |||
| dataset.async_parse_documents([document.id]) | |||
| condition(dataset) | |||
| chunks = [document.add_chunk(content=f"chunk test {i}") for i in range(4)] | |||
| # issues/6487 | |||
| from time import sleep | |||
| sleep(1) | |||
| def cleanup(): | |||
| document.delete_chunks(ids=[]) | |||
| request.addfinalizer(cleanup) | |||
| return dataset, document, chunks | |||
| @@ -0,0 +1,160 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from time import sleep | |||
| import pytest | |||
| from ragflow_sdk import Chunk | |||
| def validate_chunk_details(dataset_id: str, document_id: str, payload: dict, chunk: Chunk): | |||
| assert chunk.dataset_id == dataset_id | |||
| assert chunk.document_id == document_id | |||
| assert chunk.content == payload["content"] | |||
| if "important_keywords" in payload: | |||
| assert chunk.important_keywords == payload["important_keywords"] | |||
| if "questions" in payload: | |||
| assert chunk.questions == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()] | |||
| class TestAddChunk: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ({"content": None}, "not instance of"), | |||
| ({"content": ""}, "`content` is required"), | |||
| ({"content": 1}, "not instance of"), | |||
| ({"content": "a"}, ""), | |||
| ({"content": " "}, "`content` is required"), | |||
| ({"content": "\n!?。;!?\"'"}, ""), | |||
| ], | |||
| ) | |||
| def test_content(self, add_document, payload, expected_message): | |||
| dataset, document = add_document | |||
| chunks_count = len(document.list_chunks()) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.add_chunk(**payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunk = document.add_chunk(**payload) | |||
| validate_chunk_details(dataset.id, document.id, payload, chunk) | |||
| sleep(1) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == chunks_count + 1, str(chunks) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ({"content": "chunk test important_keywords 1", "important_keywords": ["a", "b", "c"]}, ""), | |||
| ({"content": "chunk test important_keywords 2", "important_keywords": [""]}, ""), | |||
| ({"content": "chunk test important_keywords 3", "important_keywords": [1]}, "not instance of"), | |||
| ({"content": "chunk test important_keywords 4", "important_keywords": ["a", "a"]}, ""), | |||
| ({"content": "chunk test important_keywords 5", "important_keywords": "abc"}, "not instance of"), | |||
| ({"content": "chunk test important_keywords 6", "important_keywords": 123}, "not instance of"), | |||
| ], | |||
| ) | |||
| def test_important_keywords(self, add_document, payload, expected_message): | |||
| dataset, document = add_document | |||
| chunks_count = len(document.list_chunks()) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.add_chunk(**payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunk = document.add_chunk(**payload) | |||
| validate_chunk_details(dataset.id, document.id, payload, chunk) | |||
| sleep(1) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == chunks_count + 1, str(chunks) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ({"content": "chunk test test_questions 1", "questions": ["a", "b", "c"]}, ""), | |||
| ({"content": "chunk test test_questions 2", "questions": [""]}, ""), | |||
| ({"content": "chunk test test_questions 3", "questions": [1]}, "not instance of"), | |||
| ({"content": "chunk test test_questions 4", "questions": ["a", "a"]}, ""), | |||
| ({"content": "chunk test test_questions 5", "questions": "abc"}, "not instance of"), | |||
| ({"content": "chunk test test_questions 6", "questions": 123}, "not instance of"), | |||
| ], | |||
| ) | |||
| def test_questions(self, add_document, payload, expected_message): | |||
| dataset, document = add_document | |||
| chunks_count = len(document.list_chunks()) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.add_chunk(**payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunk = document.add_chunk(**payload) | |||
| validate_chunk_details(dataset.id, document.id, payload, chunk) | |||
| sleep(1) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == chunks_count + 1, str(chunks) | |||
| @pytest.mark.p3 | |||
| def test_repeated_add_chunk(self, add_document): | |||
| payload = {"content": "chunk test repeated_add_chunk"} | |||
| dataset, document = add_document | |||
| chunks_count = len(document.list_chunks()) | |||
| chunk1 = document.add_chunk(**payload) | |||
| validate_chunk_details(dataset.id, document.id, payload, chunk1) | |||
| sleep(1) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == chunks_count + 1, str(chunks) | |||
| chunk2 = document.add_chunk(**payload) | |||
| validate_chunk_details(dataset.id, document.id, payload, chunk2) | |||
| sleep(1) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == chunks_count + 1, str(chunks) | |||
| @pytest.mark.p2 | |||
| def test_add_chunk_to_deleted_document(self, add_document): | |||
| dataset, document = add_document | |||
| dataset.delete_documents(ids=[document.id]) | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.add_chunk(content="chunk test") | |||
| assert f"You don't own the document {document.id}" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.skip(reason="issues/6411") | |||
| @pytest.mark.p3 | |||
| def test_concurrent_add_chunk(self, add_document): | |||
| count = 50 | |||
| _, document = add_document | |||
| initial_chunk_count = len(document.list_chunks()) | |||
| def add_chunk_task(i): | |||
| return document.add_chunk(content=f"chunk test concurrent {i}") | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(add_chunk_task, i) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| sleep(5) | |||
| assert len(document.list_chunks(page_size=100)) == initial_chunk_count + count | |||
| @@ -0,0 +1,113 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import batch_add_chunks | |||
| class TestChunksDeletion: | |||
| @pytest.mark.parametrize( | |||
| "payload", | |||
| [ | |||
| pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), | |||
| pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_delete_partial_invalid_id(self, add_chunks_func, payload): | |||
| _, document, chunks = add_chunks_func | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| payload = payload(chunk_ids) | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.delete_chunks(**payload) | |||
| assert "rm_chunk deleted chunks" in str(excinfo.value), str(excinfo.value) | |||
| remaining_chunks = document.list_chunks() | |||
| assert len(remaining_chunks) == 1, str(remaining_chunks) | |||
| @pytest.mark.p3 | |||
| def test_repeated_deletion(self, add_chunks_func): | |||
| _, document, chunks = add_chunks_func | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| document.delete_chunks(ids=chunk_ids) | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.delete_chunks(ids=chunk_ids) | |||
| assert "rm_chunk deleted chunks 0, expect" in str(excinfo.value), str(excinfo.value) | |||
| @pytest.mark.p3 | |||
| def test_duplicate_deletion(self, add_chunks_func): | |||
| _, document, chunks = add_chunks_func | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| document.delete_chunks(ids=chunk_ids * 2) | |||
| remaining_chunks = document.list_chunks() | |||
| assert len(remaining_chunks) == 1, str(remaining_chunks) | |||
| @pytest.mark.p3 | |||
| def test_concurrent_deletion(self, add_document): | |||
| count = 100 | |||
| _, document = add_document | |||
| chunks = batch_add_chunks(document, count) | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(document.delete_chunks, ids=[chunk_id]) for chunk_id in chunk_ids] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| @pytest.mark.p3 | |||
| def test_delete_1k(self, add_document): | |||
| count = 1_000 | |||
| _, document = add_document | |||
| chunks = batch_add_chunks(document, count) | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| from time import sleep | |||
| sleep(1) | |||
| document.delete_chunks(ids=chunk_ids) | |||
| remaining_chunks = document.list_chunks() | |||
| assert len(remaining_chunks) == 0, str(remaining_chunks) | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message, remaining", | |||
| [ | |||
| pytest.param(None, "TypeError", 5, marks=pytest.mark.skip), | |||
| pytest.param({"ids": ["invalid_id"]}, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), | |||
| pytest.param("not json", "UnboundLocalError", 5, marks=pytest.mark.skip(reason="pull/6376")), | |||
| pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"ids": r}, "", 1, marks=pytest.mark.p1), | |||
| pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remaining): | |||
| _, document, chunks = add_chunks_func | |||
| chunk_ids = [chunk.id for chunk in chunks] | |||
| if callable(payload): | |||
| payload = payload(chunk_ids) | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.delete_chunks(**payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| document.delete_chunks(**payload) | |||
| remaining_chunks = document.list_chunks() | |||
| assert len(remaining_chunks) == remaining, str(remaining_chunks) | |||
| @@ -0,0 +1,140 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import batch_add_chunks | |||
| class TestChunksList: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size, expected_message", | |||
| [ | |||
| ({"page": None, "page_size": 2}, 2, ""), | |||
| pytest.param({"page": 0, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), | |||
| ({"page": 2, "page_size": 2}, 2, ""), | |||
| ({"page": 3, "page_size": 2}, 1, ""), | |||
| ({"page": "3", "page_size": 2}, 1, ""), | |||
| pytest.param({"page": -1, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), | |||
| pytest.param({"page": "a", "page_size": 2}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_page(self, add_chunks, params, expected_page_size, expected_message): | |||
| _, document, _ = add_chunks | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.list_chunks(**params) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunks = document.list_chunks(**params) | |||
| assert len(chunks) == expected_page_size, str(chunks) | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size, expected_message", | |||
| [ | |||
| ({"page_size": None}, 5, ""), | |||
| pytest.param({"page_size": 0}, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), | |||
| pytest.param({"page_size": 0}, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), | |||
| ({"page_size": 1}, 1, ""), | |||
| ({"page_size": 6}, 5, ""), | |||
| ({"page_size": "1"}, 1, ""), | |||
| pytest.param({"page_size": -1}, 5, "", marks=pytest.mark.skip), | |||
| pytest.param({"page_size": "a"}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_page_size(self, add_chunks, params, expected_page_size, expected_message): | |||
| _, document, _ = add_chunks | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| document.list_chunks(**params) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunks = document.list_chunks(**params) | |||
| assert len(chunks) == expected_page_size, str(chunks) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size", | |||
| [ | |||
| ({"keywords": None}, 5), | |||
| ({"keywords": ""}, 5), | |||
| ({"keywords": "1"}, 1), | |||
| pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), | |||
| ({"keywords": "ragflow"}, 1), | |||
| ({"keywords": "unknown"}, 0), | |||
| ], | |||
| ) | |||
| def test_keywords(self, add_chunks, params, expected_page_size): | |||
| _, document, _ = add_chunks | |||
| chunks = document.list_chunks(**params) | |||
| assert len(chunks) == expected_page_size, str(chunks) | |||
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "chunk_id, expected_page_size, expected_message",
        [
            (None, 5, ""),
            ("", 5, ""),
            # A callable case receives the fixture's chunk-id list and picks a
            # real id at runtime (ids are not known at parametrize time).
            pytest.param(lambda r: r[0], 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")),
            pytest.param("unknown", 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip),
        ],
    )
    def test_id(self, add_chunks, chunk_id, expected_page_size, expected_message):
        """Filtering list_chunks by chunk id.

        ``None``/``""`` mean "no filter" and return all chunks; a concrete id
        must return that exact chunk. ``chunk_id`` may be a callable mapping
        the fixture's chunk-id list to the id under test.
        """
        _, document, chunks = add_chunks
        chunk_ids = [chunk.id for chunk in chunks]
        if callable(chunk_id):
            params = {"id": chunk_id(chunk_ids)}
        else:
            params = {"id": chunk_id}
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                document.list_chunks(**params)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = document.list_chunks(**params)
            if params["id"] in [None, ""]:
                # No-filter case: check the total count.
                assert len(chunks) == expected_page_size, str(chunks)
            else:
                # Concrete id: the returned chunk must carry that id.
                assert chunks[0].id == params["id"], str(chunks)
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, add_chunks): | |||
| _, document, _ = add_chunks | |||
| count = 100 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(document.list_chunks) for _ in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(len(future.result()) == 5 for future in futures) | |||
| @pytest.mark.p1 | |||
| def test_default(self, add_document): | |||
| _, document = add_document | |||
| batch_add_chunks(document, 31) | |||
| from time import sleep | |||
| sleep(3) | |||
| chunks = document.list_chunks() | |||
| assert len(chunks) == 30, str(chunks) | |||
| @@ -0,0 +1,254 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
class TestChunksRetrieval:
    """SDK tests for ``client.retrieve`` (chunk retrieval over datasets/documents).

    The ``add_chunks`` fixture supplies a dataset/document pair whose chunks
    match the question "chunk" with 4 hits. Parametrized cases with a
    non-empty ``expected_message`` must raise an exception containing that
    message; all others must return ``expected_page_size`` chunks.
    """

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            ({"question": "chunk", "dataset_ids": None}, 4, ""),
            ({"question": "chunk", "document_ids": None}, 0, "missing 1 required positional argument"),
            ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 4, ""),
            ({"question": "chunk"}, 0, "missing 1 required positional argument"),
        ],
    )
    def test_basic_scenarios(self, client, add_chunks, payload, expected_page_size, expected_message):
        """Retrieval with the various dataset_ids/document_ids combinations."""
        dataset, document, _ = add_chunks
        # None placeholders in the parametrization are filled with real ids here.
        if "dataset_ids" in payload:
            payload["dataset_ids"] = [dataset.id]
        if "document_ids" in payload:
            payload["document_ids"] = [document.id]
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            pytest.param(
                {"page": None, "page_size": 2},
                2,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": 0, "page_size": 2},
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param({"page": 2, "page_size": 2}, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
            ({"page": 3, "page_size": 2}, 0, ""),  # past the 4 matching chunks
            ({"page": "3", "page_size": 2}, 0, ""),  # numeric strings accepted
            pytest.param(
                {"page": -1, "page_size": 2},
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": "a", "page_size": 2},
                0,
                """ValueError("invalid literal for int() with base 10: \'a\'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page(self, client, add_chunks, payload, expected_page_size, expected_message):
        """Page-based slicing of retrieval results."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            pytest.param(
                {"page_size": None},
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
                marks=pytest.mark.skip,
            ),
            ({"page_size": 1}, 1, ""),
            ({"page_size": 5}, 4, ""),  # larger than the 4 hits: capped at 4
            ({"page_size": "1"}, 1, ""),
            pytest.param(
                {"page_size": "a"},
                0,
                """ValueError("invalid literal for int() with base 10: \'a\'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page_size(self, client, add_chunks, payload, expected_page_size, expected_message):
        """page_size caps the number of retrieved chunks."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            # The weight changes ranking, not the number of hits: 4 either way.
            ({"vector_similarity_weight": 0}, 4, ""),
            ({"vector_similarity_weight": 0.5}, 4, ""),
            ({"vector_similarity_weight": 10}, 4, ""),
            pytest.param(
                {"vector_similarity_weight": "a"},
                0,
                """ValueError("could not convert string to float: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_vector_similarity_weight(self, client, add_chunks, payload, expected_page_size, expected_message):
        """vector_similarity_weight is accepted over a wide numeric range."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            ({"top_k": 10}, 4, ""),
            # top_k semantics differ by engine, so the same input appears twice
            # with engine-specific expectations and complementary skipifs.
            pytest.param(
                {"top_k": 1},
                4,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": 1},
                1,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": -1},
                4,
                "must be greater than 0",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": -1},
                4,
                "3014",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": "a"},
                0,
                """ValueError("invalid literal for int() with base 10: \'a\'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_top_k(self, client, add_chunks, payload, expected_page_size, expected_message):
        """top_k candidate limiting, with per-engine expectations."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_message",
        [
            ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, ""),
            pytest.param({"rerank_id": "unknown"}, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
        ],
    )
    def test_rerank_id(self, client, add_chunks, payload, expected_message):
        """Retrieval with a rerank model (suite currently skipped)."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) > 0, str(chunks)

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_page_size, expected_message",
        [
            ({"keyword": True}, 5, ""),
            ({"keyword": "True"}, 5, ""),
            ({"keyword": False}, 5, ""),
            ({"keyword": "False"}, 5, ""),
            ({"keyword": None}, 5, ""),
        ],
    )
    def test_keyword(self, client, add_chunks, payload, expected_page_size, expected_message):
        """keyword-matching toggle (suite currently skipped)."""
        dataset, _, _ = add_chunks
        payload.update({"question": "chunk test", "dataset_ids": [dataset.id]})
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                client.retrieve(**payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunks = client.retrieve(**payload)
            assert len(chunks) == expected_page_size, str(chunks)

    @pytest.mark.p3
    def test_concurrent_retrieval(self, client, add_chunks):
        """100 parallel retrieve calls all complete without raising.

        Bug fix: the previous version only counted the futures yielded by
        ``as_completed`` and never called ``future.result()``, so an exception
        raised inside ``client.retrieve`` was silently swallowed and the test
        passed even when every call failed.
        """
        dataset, _, _ = add_chunks
        count = 100
        payload = {"question": "chunk", "dataset_ids": [dataset.id]}
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(client.retrieve, **payload) for _ in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        for future in futures:
            future.result()  # re-raises any exception from the worker thread
| @@ -0,0 +1,154 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from random import randint | |||
| import pytest | |||
| class TestUpdatedChunk: | |||
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_message",
        [
            # Non-string content is rejected before reaching the model.
            ({"content": None}, "TypeError('expected string or bytes-like object')"),
            # Empty/whitespace-only content currently fails downstream with an
            # LLM-provider 400 (message text is provider-supplied, in Chinese).
            pytest.param(
                {"content": ""},
                """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
                marks=pytest.mark.skip(reason="issues/6541"),
            ),
            pytest.param(
                {"content": 1},
                "TypeError('expected string or bytes-like object')",
                marks=pytest.mark.skip,
            ),
            ({"content": "update chunk"}, ""),
            pytest.param(
                {"content": " "},
                """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
                marks=pytest.mark.skip(reason="issues/6541"),
            ),
            # Punctuation-only content is accepted.
            ({"content": "\n!?。;!?\"'"}, ""),
        ],
    )
    def test_content(self, add_chunks, payload, expected_message):
        """chunk.update({"content": ...}) accepts non-empty strings and
        rejects non-string values; error text must match expected_message."""
        _, _, chunks = add_chunks
        chunk = chunks[0]
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                chunk.update(payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunk.update(payload)
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ({"important_keywords": ["a", "b", "c"]}, ""), | |||
| ({"important_keywords": [""]}, ""), | |||
| ({"important_keywords": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), | |||
| ({"important_keywords": ["a", "a"]}, ""), | |||
| ({"important_keywords": "abc"}, "`important_keywords` should be a list"), | |||
| ({"important_keywords": 123}, "`important_keywords` should be a list"), | |||
| ], | |||
| ) | |||
| def test_important_keywords(self, add_chunks, payload, expected_message): | |||
| _, _, chunks = add_chunks | |||
| chunk = chunks[0] | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| chunk.update(payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunk.update(payload) | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_message", | |||
| [ | |||
| ({"questions": ["a", "b", "c"]}, ""), | |||
| ({"questions": [""]}, ""), | |||
| ({"questions": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), | |||
| ({"questions": ["a", "a"]}, ""), | |||
| ({"questions": "abc"}, "`questions` should be a list"), | |||
| ({"questions": 123}, "`questions` should be a list"), | |||
| ], | |||
| ) | |||
| def test_questions(self, add_chunks, payload, expected_message): | |||
| _, _, chunks = add_chunks | |||
| chunk = chunks[0] | |||
| if expected_message: | |||
| with pytest.raises(Exception) as excinfo: | |||
| chunk.update(payload) | |||
| assert expected_message in str(excinfo.value), str(excinfo.value) | |||
| else: | |||
| chunk.update(payload) | |||
    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_message",
        [
            ({"available": True}, ""),
            # String booleans fail: per the expected message the value is
            # coerced with int(), which rejects "True"/"False".
            pytest.param({"available": "True"}, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip),
            ({"available": 1}, ""),
            ({"available": False}, ""),
            pytest.param({"available": "False"}, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip),
            ({"available": 0}, ""),
        ],
    )
    def test_available(self, add_chunks, payload, expected_message):
        """The chunk availability flag accepts bools and 0/1."""
        _, _, chunks = add_chunks
        chunk = chunks[0]
        if expected_message:
            with pytest.raises(Exception) as excinfo:
                chunk.update(payload)
            assert expected_message in str(excinfo.value), str(excinfo.value)
        else:
            chunk.update(payload)
| @pytest.mark.p3 | |||
| def test_repeated_update_chunk(self, add_chunks): | |||
| _, _, chunks = add_chunks | |||
| chunk = chunks[0] | |||
| chunk.update({"content": "chunk test 1"}) | |||
| chunk.update({"content": "chunk test 2"}) | |||
| @pytest.mark.p3 | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") | |||
| def test_concurrent_update_chunk(self, add_chunks): | |||
| count = 50 | |||
| _, _, chunks = add_chunks | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(chunks[randint(0, 3)].update, {"content": f"update chunk test {i}"}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| @pytest.mark.p3 | |||
| def test_update_chunk_to_deleted_document(self, add_chunks): | |||
| dataset, document, chunks = add_chunks | |||
| dataset.delete_documents(ids=[document.id]) | |||
| with pytest.raises(Exception) as excinfo: | |||
| chunks[0].update({}) | |||
| assert f"Can't find this chunk {chunks[0].id}" in str(excinfo.value), str(excinfo.value) | |||