### What problem does this PR solve? - Add comprehensive test suite for chunk operations including: - Test files for create, list, retrieve, update, and delete chunks - Authorization tests - Batch operations tests - Update test configurations and common utilities - Validate `important_kwd` and `question_kwd` fields are lists in chunk_app.py - Reorganize imports and clean up duplicate code ### Type of change - [x] Add test casestags/v0.20.0
| @@ -15,27 +15,25 @@ | |||
| # | |||
| import datetime | |||
| import json | |||
| import re | |||
| import xxhash | |||
| from flask import request | |||
| from flask_login import login_required, current_user | |||
| from flask_login import current_user, login_required | |||
| from rag.app.qa import rmPrefix, beAdoc | |||
| from rag.app.tag import label_question | |||
| from rag.nlp import search, rag_tokenizer | |||
| from rag.prompts import keyword_extraction, cross_languages | |||
| from rag.settings import PAGERANK_FLD | |||
| from rag.utils import rmSpace | |||
| from api import settings | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.user_service import UserTenantService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.db.services.document_service import DocumentService | |||
| from api import settings | |||
| from api.utils.api_utils import get_json_result | |||
| import xxhash | |||
| import re | |||
| from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request | |||
| from rag.app.qa import beAdoc, rmPrefix | |||
| from rag.app.tag import label_question | |||
| from rag.nlp import rag_tokenizer, search | |||
| from rag.prompts import cross_languages, keyword_extraction | |||
| from rag.settings import PAGERANK_FLD | |||
| from rag.utils import rmSpace | |||
| @manager.route('/list', methods=['POST']) # noqa: F821 | |||
| @@ -129,9 +127,13 @@ def set(): | |||
| d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) | |||
| d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | |||
| if "important_kwd" in req: | |||
| if not isinstance(req["important_kwd"], list): | |||
| return get_data_error_result(message="`important_kwd` should be a list") | |||
| d["important_kwd"] = req["important_kwd"] | |||
| d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) | |||
| if "question_kwd" in req: | |||
| if not isinstance(req["question_kwd"], list): | |||
| return get_data_error_result(message="`question_kwd` should be a list") | |||
| d["question_kwd"] = req["question_kwd"] | |||
| d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) | |||
| if "tag_kwd" in req: | |||
| @@ -235,6 +237,8 @@ def create(): | |||
| d["create_timestamp_flt"] = datetime.datetime.now().timestamp() | |||
| if "tag_feas" in req: | |||
| d["tag_feas"] = req["tag_feas"] | |||
| if "tag_feas" in req: | |||
| d["tag_feas"] = req["tag_feas"] | |||
| try: | |||
| e, doc = DocumentService.get_by_id(req["doc_id"]) | |||
| @@ -13,7 +13,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from time import sleep | |||
| import pytest | |||
| @@ -24,9 +24,7 @@ HEADERS = {"Content-Type": "application/json"} | |||
| KB_APP_URL = f"/{VERSION}/kb" | |||
| DOCUMENT_APP_URL = f"/{VERSION}/document" | |||
| # FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents" | |||
| # FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" | |||
| # CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" | |||
| CHUNK_API_URL = f"/{VERSION}/chunk" | |||
| # CHAT_ASSISTANT_API_URL = "/api/v1/chats" | |||
| # SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" | |||
| # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" | |||
| @@ -164,3 +162,42 @@ def bulk_upload_documents(auth, kb_id, num, tmp_path): | |||
| for document in res["data"]: | |||
| document_ids.append(document["id"]) | |||
| return document_ids | |||
| # CHUNK APP | |||
| def add_chunk(auth, payload=None, *, headers=HEADERS, data=None): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/create", headers=headers, auth=auth, json=payload, data=data) | |||
| return res.json() | |||
| def list_chunks(auth, payload=None, *, headers=HEADERS): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/list", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| def get_chunk(auth, params=None, *, headers=HEADERS): | |||
| res = requests.get(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/get", headers=headers, auth=auth, params=params) | |||
| return res.json() | |||
| def update_chunk(auth, payload=None, *, headers=HEADERS): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/set", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| def delete_chunks(auth, payload=None, *, headers=HEADERS): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/rm", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| def retrieval_chunks(auth, payload=None, *, headers=HEADERS): | |||
| res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/retrieval_test", headers=headers, auth=auth, json=payload) | |||
| return res.json() | |||
| def batch_add_chunks(auth, doc_id, num): | |||
| chunk_ids = [] | |||
| for i in range(num): | |||
| res = add_chunk(auth, {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"}) | |||
| chunk_ids.append(res["data"]["chunk_id"]) | |||
| return chunk_ids | |||
| @@ -13,18 +13,23 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from time import sleep | |||
| import pytest | |||
| from common import ( | |||
| batch_add_chunks, | |||
| batch_create_datasets, | |||
| bulk_upload_documents, | |||
| delete_chunks, | |||
| list_chunks, | |||
| list_documents, | |||
| list_kbs, | |||
| parse_documents, | |||
| rm_kb, | |||
| ) | |||
| # from configs import HOST_ADDRESS, VERSION | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| from pytest import FixtureRequest | |||
| # from ragflow_sdk import RAGFlow | |||
| from utils import wait_for | |||
| from utils.file_utils import ( | |||
| create_docx_file, | |||
| create_eml_file, | |||
| @@ -39,6 +44,15 @@ from utils.file_utils import ( | |||
| ) | |||
| @wait_for(30, 1, "Document parsing timeout") | |||
| def condition(_auth, _kb_id): | |||
| res = list_documents(_auth, {"kb_id": _kb_id}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "3": | |||
| return False | |||
| return True | |||
| @pytest.fixture | |||
| def generate_test_files(request: FixtureRequest, tmp_path): | |||
| file_creators = { | |||
| @@ -73,11 +87,6 @@ def WebApiAuth(auth): | |||
| return RAGFlowWebApiAuth(auth) | |||
| # @pytest.fixture(scope="session") | |||
| # def client(token: str) -> RAGFlow: | |||
| # return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) | |||
| @pytest.fixture(scope="function") | |||
| def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth): | |||
| def cleanup(): | |||
| @@ -108,3 +117,35 @@ def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> | |||
| request.addfinalizer(cleanup) | |||
| return batch_create_datasets(WebApiAuth, 1)[0] | |||
| @pytest.fixture(scope="class") | |||
| def add_document(request, WebApiAuth, add_dataset, ragflow_tmp_dir): | |||
| # def cleanup(): | |||
| # res = list_documents(WebApiAuth, {"kb_id": dataset_id}) | |||
| # for doc in res["data"]["docs"]: | |||
| # delete_document(WebApiAuth, {"doc_id": doc["id"]}) | |||
| # request.addfinalizer(cleanup) | |||
| dataset_id = add_dataset | |||
| return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 1, ragflow_tmp_dir)[0] | |||
| @pytest.fixture(scope="class") | |||
| def add_chunks(request, WebApiAuth, add_document): | |||
| def cleanup(): | |||
| res = list_chunks(WebApiAuth, {"doc_id": document_id}) | |||
| if res["code"] == 0: | |||
| chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]] | |||
| delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids}) | |||
| request.addfinalizer(cleanup) | |||
| kb_id, document_id = add_document | |||
| parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"}) | |||
| condition(WebApiAuth, kb_id) | |||
| chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4) | |||
| # issues/6487 | |||
| sleep(1) | |||
| return kb_id, document_id, chunk_ids | |||
| @@ -0,0 +1,49 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from time import sleep | |||
| import pytest | |||
| from common import batch_add_chunks, delete_chunks, list_chunks, list_documents, parse_documents | |||
| from utils import wait_for | |||
| @wait_for(30, 1, "Document parsing timeout") | |||
| def condition(_auth, _kb_id): | |||
| res = list_documents(_auth, {"kb_id": _kb_id}) | |||
| for doc in res["data"]["docs"]: | |||
| if doc["run"] != "3": | |||
| return False | |||
| return True | |||
| @pytest.fixture(scope="function") | |||
| def add_chunks_func(request, WebApiAuth, add_document): | |||
| def cleanup(): | |||
| res = list_chunks(WebApiAuth, {"doc_id": document_id}) | |||
| chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]] | |||
| delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids}) | |||
| request.addfinalizer(cleanup) | |||
| kb_id, document_id = add_document | |||
| parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"}) | |||
| condition(WebApiAuth, kb_id) | |||
| chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4) | |||
| # issues/6487 | |||
| sleep(1) | |||
| return kb_id, document_id, chunk_ids | |||
| @@ -0,0 +1,223 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import add_chunk, delete_document, get_chunk, list_chunks | |||
| from configs import INVALID_API_TOKEN | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| def validate_chunk_details(auth, kb_id, doc_id, payload, res): | |||
| chunk_id = res["data"]["chunk_id"] | |||
| res = get_chunk(auth, {"chunk_id": chunk_id}) | |||
| assert res["code"] == 0, res | |||
| chunk = res["data"] | |||
| assert chunk["doc_id"] == doc_id | |||
| assert chunk["kb_id"] == kb_id | |||
| assert chunk["content_with_weight"] == payload["content_with_weight"] | |||
| if "important_kwd" in payload: | |||
| assert chunk["important_kwd"] == payload["important_kwd"] | |||
| if "question_kwd" in payload: | |||
| expected = [str(q).strip() for q in payload.get("question_kwd", [])] | |||
| assert chunk["question_kwd"] == expected | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "invalid_auth, expected_code, expected_message", | |||
| [ | |||
| (None, 401, "<Unauthorized '401: Unauthorized'>"), | |||
| (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, invalid_auth, expected_code, expected_message): | |||
| res = add_chunk(invalid_auth) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestAddChunk: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"content_with_weight": None}, 100, """TypeError("unsupported operand type(s) for +: 'NoneType' and 'str'")"""), | |||
| ({"content_with_weight": ""}, 0, ""), | |||
| pytest.param( | |||
| {"content_with_weight": 1}, | |||
| 100, | |||
| """TypeError("unsupported operand type(s) for +: 'int' and 'str'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| ({"content_with_weight": "a"}, 0, ""), | |||
| ({"content_with_weight": " "}, 0, ""), | |||
| ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_content(self, WebApiAuth, add_document, payload, expected_code, expected_message): | |||
| kb_id, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] == 0: | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| else: | |||
| chunks_count = 0 | |||
| res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"content_with_weight": "chunk test", "important_kwd": ["a", "b", "c"]}, 0, ""), | |||
| ({"content_with_weight": "chunk test", "important_kwd": [""]}, 0, ""), | |||
| ( | |||
| {"content_with_weight": "chunk test", "important_kwd": [1]}, | |||
| 100, | |||
| "TypeError('sequence item 0: expected str instance, int found')", | |||
| ), | |||
| ({"content_with_weight": "chunk test", "important_kwd": ["a", "a"]}, 0, ""), | |||
| ({"content_with_weight": "chunk test", "important_kwd": "abc"}, 102, "`important_kwd` is required to be a list"), | |||
| ({"content_with_weight": "chunk test", "important_kwd": 123}, 102, "`important_kwd` is required to be a list"), | |||
| ], | |||
| ) | |||
| def test_important_keywords(self, WebApiAuth, add_document, payload, expected_code, expected_message): | |||
| kb_id, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] == 0: | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| else: | |||
| chunks_count = 0 | |||
| res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"content_with_weight": "chunk test", "question_kwd": ["a", "b", "c"]}, 0, ""), | |||
| ({"content_with_weight": "chunk test", "question_kwd": [""]}, 0, ""), | |||
| ({"content_with_weight": "chunk test", "question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), | |||
| ({"content_with_weight": "chunk test", "question_kwd": ["a", "a"]}, 0, ""), | |||
| ({"content_with_weight": "chunk test", "question_kwd": "abc"}, 102, "`question_kwd` is required to be a list"), | |||
| ({"content_with_weight": "chunk test", "question_kwd": 123}, 102, "`question_kwd` is required to be a list"), | |||
| ], | |||
| ) | |||
| def test_questions(self, WebApiAuth, add_document, payload, expected_code, expected_message): | |||
| kb_id, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] == 0: | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| else: | |||
| chunks_count = 0 | |||
| res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "doc_id, expected_code, expected_message", | |||
| [ | |||
| ("", 102, "Document not found!"), | |||
| ("invalid_document_id", 102, "Document not found!"), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, WebApiAuth, add_document, doc_id, expected_code, expected_message): | |||
| _, _ = add_document | |||
| res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"}) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| def test_repeated_add_chunk(self, WebApiAuth, add_document): | |||
| payload = {"content_with_weight": "chunk test"} | |||
| kb_id, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res | |||
| res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 2, res | |||
| @pytest.mark.p2 | |||
| def test_add_chunk_to_deleted_document(self, WebApiAuth, add_document): | |||
| _, doc_id = add_document | |||
| delete_document(WebApiAuth, {"doc_id": doc_id}) | |||
| res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"}) | |||
| assert res["code"] == 102, res | |||
| assert res["message"] == "Document not found!", res | |||
| @pytest.mark.skip(reason="issues/6411") | |||
| @pytest.mark.p3 | |||
| def test_concurrent_add_chunk(self, WebApiAuth, add_document): | |||
| count = 50 | |||
| _, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] == 0: | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| else: | |||
| chunks_count = 0 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| add_chunk, | |||
| WebApiAuth, | |||
| {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"}, | |||
| ) | |||
| for i in range(count) | |||
| ] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + count | |||
| @@ -0,0 +1,145 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import batch_add_chunks, list_chunks | |||
| from configs import INVALID_API_TOKEN | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "invalid_auth, expected_code, expected_message", | |||
| [ | |||
| (None, 401, "<Unauthorized '401: Unauthorized'>"), | |||
| (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, invalid_auth, expected_code, expected_message): | |||
| res = list_chunks(invalid_auth, {"doc_id": "document_id"}) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestChunksList: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip), | |||
| pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), | |||
| ({"page": 2, "size": 2}, 0, 2, ""), | |||
| ({"page": 3, "size": 2}, 0, 1, ""), | |||
| ({"page": "3", "size": 2}, 0, 1, ""), | |||
| pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), | |||
| pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_page(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): | |||
| _, doc_id, _ = add_chunks | |||
| payload = {"doc_id": doc_id} | |||
| if params: | |||
| payload.update(params) | |||
| res = list_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| ({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")"""), | |||
| pytest.param({"size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), | |||
| pytest.param({"size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), | |||
| ({"size": 1}, 0, 1, ""), | |||
| ({"size": 6}, 0, 5, ""), | |||
| ({"size": "1"}, 0, 1, ""), | |||
| pytest.param({"size": -1}, 0, 5, "", marks=pytest.mark.skip), | |||
| pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_page_size(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): | |||
| _, doc_id, _ = add_chunks | |||
| payload = {"doc_id": doc_id} | |||
| if params: | |||
| payload.update(params) | |||
| res = list_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "params, expected_page_size", | |||
| [ | |||
| ({"keywords": None}, 5), | |||
| ({"keywords": ""}, 5), | |||
| ({"keywords": "1"}, 1), | |||
| pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), | |||
| ({"keywords": "content"}, 1), | |||
| ({"keywords": "unknown"}, 0), | |||
| ], | |||
| ) | |||
| def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size): | |||
| _, doc_id, _ = add_chunks | |||
| payload = {"doc_id": doc_id} | |||
| if params: | |||
| payload.update(params) | |||
| res = list_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| @pytest.mark.p3 | |||
| def test_invalid_params(self, WebApiAuth, add_chunks): | |||
| _, doc_id, _ = add_chunks | |||
| payload = {"doc_id": doc_id, "a": "b"} | |||
| res = list_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]["chunks"]) == 5, res | |||
| @pytest.mark.p3 | |||
| def test_concurrent_list(self, WebApiAuth, add_chunks): | |||
| _, doc_id, _ = add_chunks | |||
| count = 100 | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(list_chunks, WebApiAuth, {"doc_id": doc_id}) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures) | |||
| @pytest.mark.p1 | |||
| def test_default(self, WebApiAuth, add_document): | |||
| _, doc_id = add_document | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| chunks_count = res["data"]["doc"]["chunk_num"] | |||
| batch_add_chunks(WebApiAuth, doc_id, 31) | |||
| # issues/6487 | |||
| from time import sleep | |||
| sleep(3) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0 | |||
| assert len(res["data"]["chunks"]) == 30 | |||
| assert res["data"]["doc"]["chunk_num"] == chunks_count + 31 | |||
| @@ -0,0 +1,308 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import retrieval_chunks | |||
| from configs import INVALID_API_TOKEN | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "invalid_auth, expected_code, expected_message", | |||
| [ | |||
| (None, 401, "<Unauthorized '401: Unauthorized'>"), | |||
| (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, invalid_auth, expected_code, expected_message): | |||
| res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"}) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestChunksRetrieval: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| ({"question": "chunk", "kb_id": None}, 0, 4, ""), | |||
| ({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "), | |||
| ({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""), | |||
| ({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, document_id, _ = add_chunks | |||
| if "kb_id" in payload: | |||
| payload["kb_id"] = [dataset_id] | |||
| if "doc_ids" in payload: | |||
| payload["doc_ids"] = [document_id] | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| pytest.param( | |||
| {"page": None, "size": 2}, | |||
| 100, | |||
| 0, | |||
| """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| pytest.param( | |||
| {"page": 0, "size": 2}, | |||
| 100, | |||
| 0, | |||
| "ValueError('Search does not support negative slicing.')", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")), | |||
| ({"page": 3, "size": 2}, 0, 0, ""), | |||
| ({"page": "3", "size": 2}, 0, 0, ""), | |||
| pytest.param( | |||
| {"page": -1, "size": 2}, | |||
| 100, | |||
| 0, | |||
| "ValueError('Search does not support negative slicing.')", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| pytest.param( | |||
| {"page": "a", "size": 2}, | |||
| 100, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: 'a'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| ], | |||
| ) | |||
| def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| pytest.param( | |||
| {"size": None}, | |||
| 100, | |||
| 0, | |||
| """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| # ({"size": 0}, 0, 0, ""), | |||
| ({"size": 1}, 0, 1, ""), | |||
| ({"size": 5}, 0, 4, ""), | |||
| ({"size": "1"}, 0, 1, ""), | |||
| # ({"size": -1}, 0, 0, ""), | |||
| pytest.param( | |||
| {"size": "a"}, | |||
| 100, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: 'a'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| ], | |||
| ) | |||
| def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| ({"vector_similarity_weight": 0}, 0, 4, ""), | |||
| ({"vector_similarity_weight": 0.5}, 0, 4, ""), | |||
| ({"vector_similarity_weight": 10}, 0, 4, ""), | |||
| pytest.param( | |||
| {"vector_similarity_weight": "a"}, | |||
| 100, | |||
| 0, | |||
| """ValueError("could not convert string to float: 'a'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| ], | |||
| ) | |||
| def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| ({"top_k": 10}, 0, 4, ""), | |||
| pytest.param( | |||
| {"top_k": 1}, | |||
| 0, | |||
| 4, | |||
| "", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), | |||
| ), | |||
| pytest.param( | |||
| {"top_k": 1}, | |||
| 0, | |||
| 1, | |||
| "", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), | |||
| ), | |||
| pytest.param( | |||
| {"top_k": -1}, | |||
| 100, | |||
| 4, | |||
| "must be greater than 0", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), | |||
| ), | |||
| pytest.param( | |||
| {"top_k": -1}, | |||
| 100, | |||
| 4, | |||
| "3014", | |||
| marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), | |||
| ), | |||
| pytest.param( | |||
| {"top_k": "a"}, | |||
| 100, | |||
| 0, | |||
| """ValueError("invalid literal for int() with base 10: 'a'")""", | |||
| marks=pytest.mark.skip, | |||
| ), | |||
| ], | |||
| ) | |||
| def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert expected_message in res["message"], res | |||
| @pytest.mark.skip | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""), | |||
| pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) > 0, res | |||
| else: | |||
| assert expected_message in res["message"], res | |||
| @pytest.mark.skip | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_page_size, expected_message", | |||
| [ | |||
| ({"keyword": True}, 0, 5, ""), | |||
| ({"keyword": "True"}, 0, 5, ""), | |||
| ({"keyword": False}, 0, 5, ""), | |||
| ({"keyword": "False"}, 0, 5, ""), | |||
| ({"keyword": None}, 0, 5, ""), | |||
| ], | |||
| ) | |||
| def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk test", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code == 0: | |||
| assert len(res["data"]["chunks"]) == expected_page_size, res | |||
| else: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_highlight, expected_message", | |||
| [ | |||
| ({"highlight": True}, 0, True, ""), | |||
| ({"highlight": "True"}, 0, True, ""), | |||
| pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), | |||
| pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), | |||
| pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), | |||
| ], | |||
| ) | |||
| def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message): | |||
| dataset_id, _, _ = add_chunks | |||
| payload.update({"question": "chunk", "kb_id": [dataset_id]}) | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_highlight: | |||
| for chunk in res["data"]["chunks"]: | |||
| assert "highlight" in chunk, res | |||
| else: | |||
| for chunk in res["data"]["chunks"]: | |||
| assert "highlight" not in chunk, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| def test_invalid_params(self, WebApiAuth, add_chunks): | |||
| dataset_id, _, _ = add_chunks | |||
| payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"} | |||
| res = retrieval_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]["chunks"]) == 4, res | |||
| @pytest.mark.p3 | |||
| def test_concurrent_retrieval(self, WebApiAuth, add_chunks): | |||
| dataset_id, _, _ = add_chunks | |||
| count = 100 | |||
| payload = {"question": "chunk", "kb_id": [dataset_id]} | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for i in range(count)] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @@ -0,0 +1,161 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| import pytest | |||
| from common import batch_add_chunks, delete_chunks, list_chunks | |||
| from configs import INVALID_API_TOKEN | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "invalid_auth, expected_code, expected_message", | |||
| [ | |||
| (None, 401, "<Unauthorized '401: Unauthorized'>"), | |||
| (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, invalid_auth, expected_code, expected_message): | |||
| res = delete_chunks(invalid_auth, {"doc_id": "document_id", "chunk_ids": ["1"]}) | |||
| assert res["code"] == expected_code | |||
| assert res["message"] == expected_message | |||
| class TestChunksDeletion: | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "doc_id, expected_code, expected_message", | |||
| [ | |||
| ("", 102, "Document not found!"), | |||
| ("invalid_document_id", 102, "Document not found!"), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id(self, WebApiAuth, add_chunks_func, doc_id, expected_code, expected_message): | |||
| _, _, chunk_ids = add_chunks_func | |||
| res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids}) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.parametrize( | |||
| "payload", | |||
| [ | |||
| pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), | |||
| pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_delete_partial_invalid_id(self, WebApiAuth, add_chunks_func, payload): | |||
| _, doc_id, chunk_ids = add_chunks_func | |||
| if callable(payload): | |||
| payload = payload(chunk_ids) | |||
| payload["doc_id"] = doc_id | |||
| res = delete_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]["chunks"]) == 0, res | |||
| assert res["data"]["total"] == 0, res | |||
| @pytest.mark.p3 | |||
| def test_repeated_deletion(self, WebApiAuth, add_chunks_func): | |||
| _, doc_id, chunk_ids = add_chunks_func | |||
| payload = {"chunk_ids": chunk_ids, "doc_id": doc_id} | |||
| res = delete_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| res = delete_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["message"] == "Index updating failure", res | |||
| @pytest.mark.p3 | |||
| def test_duplicate_deletion(self, WebApiAuth, add_chunks_func): | |||
| _, doc_id, chunk_ids = add_chunks_func | |||
| payload = {"chunk_ids": chunk_ids * 2, "doc_id": doc_id} | |||
| res = delete_chunks(WebApiAuth, payload) | |||
| assert res["code"] == 0, res | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| assert res["code"] == 0, res | |||
| assert len(res["data"]["chunks"]) == 0, res | |||
| assert res["data"]["total"] == 0, res | |||
| @pytest.mark.p3 | |||
| def test_concurrent_deletion(self, WebApiAuth, add_document): | |||
| count = 100 | |||
| _, doc_id = add_document | |||
| chunk_ids = batch_add_chunks(WebApiAuth, doc_id, count) | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| delete_chunks, | |||
| WebApiAuth, | |||
| {"doc_id": doc_id, "chunk_ids": chunk_ids[i : i + 1]}, | |||
| ) | |||
| for i in range(count) | |||
| ] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_delete_1k(self, WebApiAuth, add_document): | |||
| chunks_num = 1_000 | |||
| _, doc_id = add_document | |||
| chunk_ids = batch_add_chunks(WebApiAuth, doc_id, chunks_num) | |||
| from time import sleep | |||
| sleep(1) | |||
| res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids}) | |||
| assert res["code"] == 0 | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| assert len(res["data"]["chunks"]) == 0, res | |||
| assert res["data"]["total"] == 0, res | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message, remaining", | |||
| [ | |||
| pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", 5, marks=pytest.mark.skip), | |||
| pytest.param({"chunk_ids": ["invalid_id"]}, 102, "Index updating failure", 4, marks=pytest.mark.p3), | |||
| pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")), | |||
| pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3), | |||
| pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1), | |||
| pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3), | |||
| ], | |||
| ) | |||
| def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining): | |||
| _, doc_id, chunk_ids = add_chunks_func | |||
| if callable(payload): | |||
| payload = payload(chunk_ids) | |||
| payload["doc_id"] = doc_id | |||
| res = delete_chunks(WebApiAuth, payload) | |||
| assert res["code"] == expected_code, res | |||
| if res["code"] != 0: | |||
| assert res["message"] == expected_message, res | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| if res["code"] != 0: | |||
| assert False, res | |||
| assert len(res["data"]["chunks"]) == remaining, res | |||
| assert res["data"]["total"] == remaining, res | |||
| @@ -0,0 +1,232 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from random import randint | |||
| from time import sleep | |||
| import pytest | |||
| from common import delete_document, list_chunks, update_chunk | |||
| from configs import INVALID_API_TOKEN | |||
| from libs.auth import RAGFlowWebApiAuth | |||
| @pytest.mark.p1 | |||
| class TestAuthorization: | |||
| @pytest.mark.parametrize( | |||
| "invalid_auth, expected_code, expected_message", | |||
| [ | |||
| (None, 401, "<Unauthorized '401: Unauthorized'>"), | |||
| (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"), | |||
| ], | |||
| ) | |||
| def test_invalid_auth(self, invalid_auth, expected_code, expected_message): | |||
| res = update_chunk(invalid_auth, {"doc_id": "doc_id", "chunk_id": "chunk_id", "content_with_weight": "test"}) | |||
| assert res["code"] == expected_code, res | |||
| assert res["message"] == expected_message, res | |||
| class TestUpdateChunk: | |||
| @pytest.mark.p1 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"content_with_weight": None}, 100, "TypeError('expected string or bytes-like object')"), | |||
| ({"content_with_weight": ""}, 0, ""), | |||
| ({"content_with_weight": 1}, 100, "TypeError('expected string or bytes-like object')"), | |||
| ({"content_with_weight": "update chunk"}, 0, ""), | |||
| ({"content_with_weight": " "}, 0, ""), | |||
| ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_content(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| update_payload = {"doc_id": doc_id, "chunk_id": chunk_id} | |||
| if payload: | |||
| update_payload.update(payload) | |||
| res = update_chunk(WebApiAuth, update_payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| else: | |||
| sleep(1) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| for chunk in res["data"]["chunks"]: | |||
| if chunk["chunk_id"] == chunk_id: | |||
| assert chunk["content_with_weight"] == payload["content_with_weight"] | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"important_kwd": ["a", "b", "c"]}, 0, ""), | |||
| ({"important_kwd": [""]}, 0, ""), | |||
| ({"important_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), | |||
| ({"important_kwd": ["a", "a"]}, 0, ""), | |||
| ({"important_kwd": "abc"}, 102, "`important_kwd` should be a list"), | |||
| ({"important_kwd": 123}, 102, "`important_kwd` should be a list"), | |||
| ], | |||
| ) | |||
| def test_important_keywords(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} # Add content_with_weight as it's required | |||
| if payload: | |||
| update_payload.update(payload) | |||
| res = update_chunk(WebApiAuth, update_payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| else: | |||
| sleep(1) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| for chunk in res["data"]["chunks"]: | |||
| if chunk["chunk_id"] == chunk_id: | |||
| assert chunk["important_kwd"] == payload["important_kwd"] | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"question_kwd": ["a", "b", "c"]}, 0, ""), | |||
| ({"question_kwd": [""]}, 0, ""), | |||
| ({"question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), | |||
| ({"question_kwd": ["a", "a"]}, 0, ""), | |||
| ({"question_kwd": "abc"}, 102, "`question_kwd` should be a list"), | |||
| ({"question_kwd": 123}, 102, "`question_kwd` should be a list"), | |||
| ], | |||
| ) | |||
| def test_questions(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} # Add content_with_weight as it's required | |||
| if payload: | |||
| update_payload.update(payload) | |||
| res = update_chunk(WebApiAuth, update_payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| else: | |||
| sleep(1) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| for chunk in res["data"]["chunks"]: | |||
| if chunk["chunk_id"] == chunk_id: | |||
| assert chunk["question_kwd"] == payload["question_kwd"] | |||
| @pytest.mark.p2 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"available_int": 1}, 0, ""), | |||
| ({"available_int": 0}, 0, ""), | |||
| ], | |||
| ) | |||
| def test_available(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} | |||
| if payload: | |||
| update_payload.update(payload) | |||
| res = update_chunk(WebApiAuth, update_payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| else: | |||
| sleep(1) | |||
| res = list_chunks(WebApiAuth, {"doc_id": doc_id}) | |||
| for chunk in res["data"]["chunks"]: | |||
| if chunk["chunk_id"] == chunk_id: | |||
| assert chunk["available_int"] == payload["available_int"] | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "doc_id_param, expected_code, expected_message", | |||
| [ | |||
| ("", 102, "Tenant not found!"), | |||
| ("invalid_doc_id", 102, "Tenant not found!"), | |||
| ], | |||
| ) | |||
| def test_invalid_document_id_for_update(self, WebApiAuth, add_chunks, doc_id_param, expected_code, expected_message): | |||
| _, _, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| payload = {"doc_id": doc_id_param, "chunk_id": chunk_id, "content_with_weight": "test content"} | |||
| res = update_chunk(WebApiAuth, payload) | |||
| assert res["code"] == expected_code | |||
| assert expected_message in res["message"] | |||
| @pytest.mark.p3 | |||
| def test_repeated_update_chunk(self, WebApiAuth, add_chunks): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| payload1 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 1"} | |||
| res = update_chunk(WebApiAuth, payload1) | |||
| assert res["code"] == 0 | |||
| payload2 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 2"} | |||
| res = update_chunk(WebApiAuth, payload2) | |||
| assert res["code"] == 0 | |||
| @pytest.mark.p3 | |||
| @pytest.mark.parametrize( | |||
| "payload, expected_code, expected_message", | |||
| [ | |||
| ({"unknown_key": "unknown_value"}, 0, ""), | |||
| ({}, 0, ""), | |||
| pytest.param(None, 100, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip), | |||
| ], | |||
| ) | |||
| def test_invalid_params(self, WebApiAuth, add_chunks, payload, expected_code, expected_message): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| chunk_id = chunk_ids[0] | |||
| update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"} | |||
| if payload is not None: | |||
| update_payload.update(payload) | |||
| res = update_chunk(WebApiAuth, update_payload) | |||
| assert res["code"] == expected_code, res | |||
| if expected_code != 0: | |||
| assert res["message"] == expected_message, res | |||
| @pytest.mark.p3 | |||
| @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") | |||
| def test_concurrent_update_chunk(self, WebApiAuth, add_chunks): | |||
| count = 50 | |||
| _, doc_id, chunk_ids = add_chunks | |||
| with ThreadPoolExecutor(max_workers=5) as executor: | |||
| futures = [ | |||
| executor.submit( | |||
| update_chunk, | |||
| WebApiAuth, | |||
| {"doc_id": doc_id, "chunk_id": chunk_ids[randint(0, 3)], "content_with_weight": f"update chunk test {i}"}, | |||
| ) | |||
| for i in range(count) | |||
| ] | |||
| responses = list(as_completed(futures)) | |||
| assert len(responses) == count, responses | |||
| assert all(future.result()["code"] == 0 for future in futures) | |||
| @pytest.mark.p3 | |||
| def test_update_chunk_to_deleted_document(self, WebApiAuth, add_chunks): | |||
| _, doc_id, chunk_ids = add_chunks | |||
| delete_document(WebApiAuth, {"doc_id": doc_id}) | |||
| payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "test content"} | |||
| res = update_chunk(WebApiAuth, payload) | |||
| assert res["code"] == 102, res | |||
| assert res["message"] == "Tenant not found!", res | |||